[breaking] Remove support for single string feature info. (#9401)

- Input must be a sequence of strings.
- Improve validation error message.
This commit is contained in:
Jiaming Yuan 2023-07-24 11:06:30 +08:00 committed by GitHub
parent 275da176ba
commit 01e00efc53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 90 additions and 104 deletions

View File

@ -297,6 +297,23 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
) )
def _validate_feature_info(
feature_info: Sequence[str], n_features: int, name: str
) -> List[str]:
if isinstance(feature_info, str) or not isinstance(feature_info, Sequence):
raise TypeError(
f"Expecting a sequence of strings for {name}, got: {type(feature_info)}"
)
feature_info = list(feature_info)
if len(feature_info) != n_features and n_features != 0:
msg = (
f"{name} must have the same length as the number of data columns, ",
f"expected {n_features}, got {len(feature_info)}",
)
raise ValueError(msg)
return feature_info
def build_info() -> dict: def build_info() -> dict:
"""Build information of XGBoost. The returned value format is not stable. Also, """Build information of XGBoost. The returned value format is not stable. Also,
please note that build time dependency is not the same as runtime dependency. For please note that build time dependency is not the same as runtime dependency. For
@ -1217,11 +1234,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
@property @property
def feature_names(self) -> Optional[FeatureNames]: def feature_names(self) -> Optional[FeatureNames]:
"""Get feature names (column labels). """Labels for features (column labels).
Setting it to ``None`` resets existing feature names.
Returns
-------
feature_names : list or None
""" """
length = c_bst_ulong() length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)() sarr = ctypes.POINTER(ctypes.c_char_p)()
@ -1240,32 +1256,31 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
@feature_names.setter @feature_names.setter
def feature_names(self, feature_names: Optional[FeatureNames]) -> None: def feature_names(self, feature_names: Optional[FeatureNames]) -> None:
"""Set feature names (column labels). if feature_names is None:
_check_call(
Parameters _LIB.XGDMatrixSetStrFeatureInfo(
---------- self.handle, c_str("feature_name"), None, c_bst_ulong(0)
feature_names : list or None
Labels for features. None will reset existing feature names
"""
if feature_names is not None:
# validate feature name
try:
if not isinstance(feature_names, str):
feature_names = list(feature_names)
else:
feature_names = [feature_names]
except TypeError:
feature_names = [cast(str, feature_names)]
if len(feature_names) != len(set(feature_names)):
raise ValueError("feature_names must be unique")
if len(feature_names) != self.num_col() and self.num_col() != 0:
msg = (
"feature_names must have the same length as data, ",
f"expected {self.num_col()}, got {len(feature_names)}",
) )
raise ValueError(msg) )
# prohibit to use symbols may affect to parse. e.g. []< return
# validate feature name
feature_names = _validate_feature_info(
feature_names, self.num_col(), "feature names"
)
if len(feature_names) != len(set(feature_names)):
values, counts = np.unique(
feature_names,
return_index=False,
return_inverse=False,
return_counts=True,
)
duplicates = [name for name, cnt in zip(values, counts) if cnt > 1]
raise ValueError(
f"feature_names must be unique. Duplicates found: {duplicates}"
)
# prohibit the use of symbols that may affect parsing, e.g. []<
if not all( if not all(
isinstance(f, str) and not any(x in f for x in ["[", "]", "<"]) isinstance(f, str) and not any(x in f for x in ["[", "]", "<"])
for f in feature_names for f in feature_names
@ -1273,6 +1288,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
raise ValueError( raise ValueError(
"feature_names must be string, and may not contain [, ] or <" "feature_names must be string, and may not contain [, ] or <"
) )
feature_names_bytes = [bytes(f, encoding="utf-8") for f in feature_names] feature_names_bytes = [bytes(f, encoding="utf-8") for f in feature_names]
c_feature_names = (ctypes.c_char_p * len(feature_names_bytes))( c_feature_names = (ctypes.c_char_p * len(feature_names_bytes))(
*feature_names_bytes *feature_names_bytes
@ -1285,22 +1301,16 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
c_bst_ulong(len(feature_names)), c_bst_ulong(len(feature_names)),
) )
) )
else:
# reset feature_types also
_check_call(
_LIB.XGDMatrixSetStrFeatureInfo(
self.handle, c_str("feature_name"), None, c_bst_ulong(0)
)
)
self.feature_types = None
@property @property
def feature_types(self) -> Optional[FeatureTypes]: def feature_types(self) -> Optional[FeatureTypes]:
"""Get feature types (column types). """Type of features (column types).
This is for displaying the results and categorical data support. See
:py:class:`DMatrix` for details.
Setting it to ``None`` resets existing feature types.
Returns
-------
feature_types : list or None
""" """
length = c_bst_ulong() length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)() sarr = ctypes.POINTER(ctypes.c_char_p)()
@ -1318,34 +1328,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
return res return res
@feature_types.setter @feature_types.setter
def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None: def feature_types(self, feature_types: Optional[FeatureTypes]) -> None:
"""Set feature types (column types). if feature_types is None:
# Reset
_check_call(
_LIB.XGDMatrixSetStrFeatureInfo(
self.handle, c_str("feature_type"), None, c_bst_ulong(0)
)
)
return
This is for displaying the results and categorical data support. See feature_types = _validate_feature_info(
:py:class:`DMatrix` for details. feature_types, self.num_col(), "feature types"
)
Parameters
----------
feature_types :
Labels for features. None will reset existing feature names
"""
# For compatibility reason this function wraps single str input into a list. But
# we should not promote such usage since other than visualization, the field is
# also used for specifying categorical data type.
if feature_types is not None:
if not isinstance(feature_types, (list, str)):
raise TypeError("feature_types must be string or list of strings")
if isinstance(feature_types, str):
# single string will be applied to all columns
feature_types = [feature_types] * self.num_col()
try:
if not isinstance(feature_types, str):
feature_types = list(feature_types)
else:
feature_types = [feature_types]
except TypeError:
feature_types = [cast(str, feature_types)]
feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types] feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types]
c_feature_types = (ctypes.c_char_p * len(feature_types_bytes))( c_feature_types = (ctypes.c_char_p * len(feature_types_bytes))(
*feature_types_bytes *feature_types_bytes
@ -1359,17 +1355,6 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
) )
) )
if len(feature_types) != self.num_col() and self.num_col() != 0:
msg = "feature_types must have the same length as data"
raise ValueError(msg)
else:
# Reset.
_check_call(
_LIB.XGDMatrixSetStrFeatureInfo(
self.handle, c_str("feature_type"), None, c_bst_ulong(0)
)
)
class _ProxyDMatrix(DMatrix): class _ProxyDMatrix(DMatrix):
"""A placeholder class when DMatrix cannot be constructed (QuantileDMatrix, """A placeholder class when DMatrix cannot be constructed (QuantileDMatrix,

View File

@ -219,8 +219,8 @@ class TestDMatrix:
assert dm.slice([0, 1]).num_col() == dm.num_col() assert dm.slice([0, 1]).num_col() == dm.num_col()
assert dm.slice([0, 1]).feature_names == dm.feature_names assert dm.slice([0, 1]).feature_names == dm.feature_names
dm.feature_types = 'q' with pytest.raises(ValueError, match=r"Duplicates found: \['bar'\]"):
assert dm.feature_types == list('qqqqq') dm.feature_names = ["bar"] * (data.shape[1] - 2) + ["a", "b"]
dm.feature_types = list('qiqiq') dm.feature_types = list('qiqiq')
assert dm.feature_types == list('qiqiq') assert dm.feature_types == list('qiqiq')
@ -230,6 +230,7 @@ class TestDMatrix:
# reset # reset
dm.feature_names = None dm.feature_names = None
dm.feature_types = None
assert dm.feature_names is None assert dm.feature_names is None
assert dm.feature_types is None assert dm.feature_types is None