diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 4cacd61f3..70ef3535d 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -297,6 +297,29 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
         )
 
 
+def _validate_feature_info(
+    feature_info: Sequence[str], n_features: int, name: str
+) -> List[str]:
+    """Check a feature names/types input and return it as a list.
+
+    Raises ``TypeError`` when the input is a single string or not a sequence, and
+    ``ValueError`` when its length differs from a non-zero ``n_features``.
+    """
+    if isinstance(feature_info, str) or not isinstance(feature_info, Sequence):
+        raise TypeError(
+            f"Expecting a sequence of strings for {name}, got: {type(feature_info)}"
+        )
+    feature_info = list(feature_info)
+    if len(feature_info) != n_features and n_features != 0:
+        # No commas: adjacent literals must concatenate into one message string.
+        msg = (
+            f"{name} must have the same length as the number of data columns, "
+            f"expected {n_features}, got {len(feature_info)}"
+        )
+        raise ValueError(msg)
+    return feature_info
+
+
 def build_info() -> dict:
     """Build information of XGBoost. The returned value format is not stable. Also,
     please note that build time dependency is not the same as runtime dependency. For
@@ -1217,11 +1240,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
 
     @property
     def feature_names(self) -> Optional[FeatureNames]:
-        """Get feature names (column labels).
+        """Labels for features (column labels).
+
+        Setting it to ``None`` resets existing feature names.
 
-        Returns
-        -------
-        feature_names : list or None
         """
         length = c_bst_ulong()
         sarr = ctypes.POINTER(ctypes.c_char_p)()
@@ -1240,67 +1262,62 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
 
     @feature_names.setter
     def feature_names(self, feature_names: Optional[FeatureNames]) -> None:
-        """Set feature names (column labels).
-
-        Parameters
-        ----------
-        feature_names : list or None
-            Labels for features. None will reset existing feature names
-        """
-        if feature_names is not None:
-            # validate feature name
-            try:
-                if not isinstance(feature_names, str):
-                    feature_names = list(feature_names)
-                else:
-                    feature_names = [feature_names]
-            except TypeError:
-                feature_names = [cast(str, feature_names)]
-
-            if len(feature_names) != len(set(feature_names)):
-                raise ValueError("feature_names must be unique")
-            if len(feature_names) != self.num_col() and self.num_col() != 0:
-                msg = (
-                    "feature_names must have the same length as data, ",
-                    f"expected {self.num_col()}, got {len(feature_names)}",
-                )
-                raise ValueError(msg)
-            # prohibit to use symbols may affect to parse. e.g. []<
-            if not all(
-                isinstance(f, str) and not any(x in f for x in ["[", "]", "<"])
-                for f in feature_names
-            ):
-                raise ValueError(
-                    "feature_names must be string, and may not contain [, ] or <"
-                )
-            feature_names_bytes = [bytes(f, encoding="utf-8") for f in feature_names]
-            c_feature_names = (ctypes.c_char_p * len(feature_names_bytes))(
-                *feature_names_bytes
-            )
-            _check_call(
-                _LIB.XGDMatrixSetStrFeatureInfo(
-                    self.handle,
-                    c_str("feature_name"),
-                    c_feature_names,
-                    c_bst_ulong(len(feature_names)),
-                )
-            )
-        else:
-            # reset feature_types also
+        if feature_names is None:
             _check_call(
                 _LIB.XGDMatrixSetStrFeatureInfo(
                     self.handle, c_str("feature_name"), None, c_bst_ulong(0)
                 )
             )
-            self.feature_types = None
+            return
+
+        # validate feature name
+        feature_names = _validate_feature_info(
+            feature_names, self.num_col(), "feature names"
+        )
+        if len(feature_names) != len(set(feature_names)):
+            values, counts = np.unique(
+                feature_names,
+                return_index=False,
+                return_inverse=False,
+                return_counts=True,
+            )
+            # Cast NumPy scalars to str so the message shows plain names.
+            duplicates = [str(name) for name, cnt in zip(values, counts) if cnt > 1]
+            raise ValueError(
+                f"feature_names must be unique. Duplicates found: {duplicates}"
+            )
+
+        # prohibit the use of symbols that may affect parsing. e.g. []<
+        if not all(
+            isinstance(f, str) and not any(x in f for x in ["[", "]", "<"])
+            for f in feature_names
+        ):
+            raise ValueError(
+                "feature_names must be string, and may not contain [, ] or <"
+            )
+
+        feature_names_bytes = [bytes(f, encoding="utf-8") for f in feature_names]
+        c_feature_names = (ctypes.c_char_p * len(feature_names_bytes))(
+            *feature_names_bytes
+        )
+        _check_call(
+            _LIB.XGDMatrixSetStrFeatureInfo(
+                self.handle,
+                c_str("feature_name"),
+                c_feature_names,
+                c_bst_ulong(len(feature_names)),
+            )
+        )
 
     @property
     def feature_types(self) -> Optional[FeatureTypes]:
-        """Get feature types (column types).
+        """Type of features (column types).
+
+        This is for displaying the results and categorical data support. See
+        :py:class:`DMatrix` for details.
+
+        Setting it to ``None`` resets existing feature types.
 
-        Returns
-        -------
-        feature_types : list or None
         """
         length = c_bst_ulong()
         sarr = ctypes.POINTER(ctypes.c_char_p)()
@@ -1318,57 +1335,32 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
         return res
 
     @feature_types.setter
-    def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None:
-        """Set feature types (column types).
-
-        This is for displaying the results and categorical data support. See
-        :py:class:`DMatrix` for details.
-
-        Parameters
-        ----------
-        feature_types :
-            Labels for features. None will reset existing feature names
-
-        """
-        # For compatibility reason this function wraps single str input into a list. But
-        # we should not promote such usage since other than visualization, the field is
-        # also used for specifying categorical data type.
-        if feature_types is not None:
-            if not isinstance(feature_types, (list, str)):
-                raise TypeError("feature_types must be string or list of strings")
-            if isinstance(feature_types, str):
-                # single string will be applied to all columns
-                feature_types = [feature_types] * self.num_col()
-            try:
-                if not isinstance(feature_types, str):
-                    feature_types = list(feature_types)
-                else:
-                    feature_types = [feature_types]
-            except TypeError:
-                feature_types = [cast(str, feature_types)]
-            feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types]
-            c_feature_types = (ctypes.c_char_p * len(feature_types_bytes))(
-                *feature_types_bytes
-            )
-            _check_call(
-                _LIB.XGDMatrixSetStrFeatureInfo(
-                    self.handle,
-                    c_str("feature_type"),
-                    c_feature_types,
-                    c_bst_ulong(len(feature_types)),
-                )
-            )
-
-            if len(feature_types) != self.num_col() and self.num_col() != 0:
-                msg = "feature_types must have the same length as data"
-                raise ValueError(msg)
-        else:
-            # Reset.
+    def feature_types(self, feature_types: Optional[FeatureTypes]) -> None:
+        if feature_types is None:
+            # Reset
             _check_call(
                 _LIB.XGDMatrixSetStrFeatureInfo(
                     self.handle, c_str("feature_type"), None, c_bst_ulong(0)
                 )
             )
+            return
+
+        feature_types = _validate_feature_info(
+            feature_types, self.num_col(), "feature types"
+        )
+
+        feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types]
+        c_feature_types = (ctypes.c_char_p * len(feature_types_bytes))(
+            *feature_types_bytes
+        )
+        _check_call(
+            _LIB.XGDMatrixSetStrFeatureInfo(
+                self.handle,
+                c_str("feature_type"),
+                c_feature_types,
+                c_bst_ulong(len(feature_types)),
+            )
+        )
 
 
 class _ProxyDMatrix(DMatrix):
diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py
index bcc089afb..73e2055b7 100644
--- a/tests/python/test_dmatrix.py
+++ b/tests/python/test_dmatrix.py
@@ -219,8 +219,8 @@ class TestDMatrix:
         assert dm.slice([0, 1]).num_col() == dm.num_col()
         assert dm.slice([0, 1]).feature_names == dm.feature_names
 
-        dm.feature_types = 'q'
-        assert dm.feature_types == list('qqqqq')
+        with pytest.raises(ValueError, match=r"Duplicates found: \['bar'\]"):
+            dm.feature_names = ["bar"] * (data.shape[1] - 2) + ["a", "b"]
 
         dm.feature_types = list('qiqiq')
         assert dm.feature_types == list('qiqiq')
@@ -230,6 +230,7 @@ class TestDMatrix:
 
         # reset
         dm.feature_names = None
+        dm.feature_types = None
         assert dm.feature_names is None
         assert dm.feature_types is None
 