[breaking] Remove support for single string feature info. (#9401)

- Input must be a sequence of strings.
- Improve validation error message.
This commit is contained in:
Jiaming Yuan 2023-07-24 11:06:30 +08:00 committed by GitHub
parent 275da176ba
commit 01e00efc53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 90 additions and 104 deletions

View File

@ -297,6 +297,23 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
)
def _validate_feature_info(
feature_info: Sequence[str], n_features: int, name: str
) -> List[str]:
if isinstance(feature_info, str) or not isinstance(feature_info, Sequence):
raise TypeError(
f"Expecting a sequence of strings for {name}, got: {type(feature_info)}"
)
feature_info = list(feature_info)
if len(feature_info) != n_features and n_features != 0:
msg = (
f"{name} must have the same length as the number of data columns, ",
f"expected {n_features}, got {len(feature_info)}",
)
raise ValueError(msg)
return feature_info
def build_info() -> dict:
"""Build information of XGBoost. The returned value format is not stable. Also,
please note that build time dependency is not the same as runtime dependency. For
@ -1217,11 +1234,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
@property
def feature_names(self) -> Optional[FeatureNames]:
"""Get feature names (column labels).
"""Labels for features (column labels).
Setting it to ``None`` resets existing feature names.
Returns
-------
feature_names : list or None
"""
length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)()
@ -1240,67 +1256,61 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
@feature_names.setter
def feature_names(self, feature_names: Optional[FeatureNames]) -> None:
"""Set feature names (column labels).
Parameters
----------
feature_names : list or None
Labels for features. None will reset existing feature names
"""
if feature_names is not None:
# validate feature name
try:
if not isinstance(feature_names, str):
feature_names = list(feature_names)
else:
feature_names = [feature_names]
except TypeError:
feature_names = [cast(str, feature_names)]
if len(feature_names) != len(set(feature_names)):
raise ValueError("feature_names must be unique")
if len(feature_names) != self.num_col() and self.num_col() != 0:
msg = (
"feature_names must have the same length as data, ",
f"expected {self.num_col()}, got {len(feature_names)}",
)
raise ValueError(msg)
# prohibit to use symbols may affect to parse. e.g. []<
if not all(
isinstance(f, str) and not any(x in f for x in ["[", "]", "<"])
for f in feature_names
):
raise ValueError(
"feature_names must be string, and may not contain [, ] or <"
)
feature_names_bytes = [bytes(f, encoding="utf-8") for f in feature_names]
c_feature_names = (ctypes.c_char_p * len(feature_names_bytes))(
*feature_names_bytes
)
_check_call(
_LIB.XGDMatrixSetStrFeatureInfo(
self.handle,
c_str("feature_name"),
c_feature_names,
c_bst_ulong(len(feature_names)),
)
)
else:
# reset feature_types also
if feature_names is None:
_check_call(
_LIB.XGDMatrixSetStrFeatureInfo(
self.handle, c_str("feature_name"), None, c_bst_ulong(0)
)
)
self.feature_types = None
return
# validate feature name
feature_names = _validate_feature_info(
feature_names, self.num_col(), "feature names"
)
if len(feature_names) != len(set(feature_names)):
values, counts = np.unique(
feature_names,
return_index=False,
return_inverse=False,
return_counts=True,
)
duplicates = [name for name, cnt in zip(values, counts) if cnt > 1]
raise ValueError(
f"feature_names must be unique. Duplicates found: {duplicates}"
)
# prohibit the use symbols that may affect parsing. e.g. []<
if not all(
isinstance(f, str) and not any(x in f for x in ["[", "]", "<"])
for f in feature_names
):
raise ValueError(
"feature_names must be string, and may not contain [, ] or <"
)
feature_names_bytes = [bytes(f, encoding="utf-8") for f in feature_names]
c_feature_names = (ctypes.c_char_p * len(feature_names_bytes))(
*feature_names_bytes
)
_check_call(
_LIB.XGDMatrixSetStrFeatureInfo(
self.handle,
c_str("feature_name"),
c_feature_names,
c_bst_ulong(len(feature_names)),
)
)
@property
def feature_types(self) -> Optional[FeatureTypes]:
"""Get feature types (column types).
"""Type of features (column types).
This is for displaying the results and categorical data support. See
:py:class:`DMatrix` for details.
Setting it to ``None`` resets existing feature types.
Returns
-------
feature_types : list or None
"""
length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)()
@ -1318,57 +1328,32 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
return res
@feature_types.setter
def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None:
"""Set feature types (column types).
This is for displaying the results and categorical data support. See
:py:class:`DMatrix` for details.
Parameters
----------
feature_types :
Labels for features. None will reset existing feature names
"""
# For compatibility reason this function wraps single str input into a list. But
# we should not promote such usage since other than visualization, the field is
# also used for specifying categorical data type.
if feature_types is not None:
if not isinstance(feature_types, (list, str)):
raise TypeError("feature_types must be string or list of strings")
if isinstance(feature_types, str):
# single string will be applied to all columns
feature_types = [feature_types] * self.num_col()
try:
if not isinstance(feature_types, str):
feature_types = list(feature_types)
else:
feature_types = [feature_types]
except TypeError:
feature_types = [cast(str, feature_types)]
feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types]
c_feature_types = (ctypes.c_char_p * len(feature_types_bytes))(
*feature_types_bytes
)
_check_call(
_LIB.XGDMatrixSetStrFeatureInfo(
self.handle,
c_str("feature_type"),
c_feature_types,
c_bst_ulong(len(feature_types)),
)
)
if len(feature_types) != self.num_col() and self.num_col() != 0:
msg = "feature_types must have the same length as data"
raise ValueError(msg)
else:
# Reset.
def feature_types(self, feature_types: Optional[FeatureTypes]) -> None:
if feature_types is None:
# Reset
_check_call(
_LIB.XGDMatrixSetStrFeatureInfo(
self.handle, c_str("feature_type"), None, c_bst_ulong(0)
)
)
return
feature_types = _validate_feature_info(
feature_types, self.num_col(), "feature types"
)
feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types]
c_feature_types = (ctypes.c_char_p * len(feature_types_bytes))(
*feature_types_bytes
)
_check_call(
_LIB.XGDMatrixSetStrFeatureInfo(
self.handle,
c_str("feature_type"),
c_feature_types,
c_bst_ulong(len(feature_types)),
)
)
class _ProxyDMatrix(DMatrix):

View File

@ -219,8 +219,8 @@ class TestDMatrix:
assert dm.slice([0, 1]).num_col() == dm.num_col()
assert dm.slice([0, 1]).feature_names == dm.feature_names
dm.feature_types = 'q'
assert dm.feature_types == list('qqqqq')
with pytest.raises(ValueError, match=r"Duplicates found: \['bar'\]"):
dm.feature_names = ["bar"] * (data.shape[1] - 2) + ["a", "b"]
dm.feature_types = list('qiqiq')
assert dm.feature_types == list('qiqiq')
@ -230,6 +230,7 @@ class TestDMatrix:
# reset
dm.feature_names = None
dm.feature_types = None
assert dm.feature_names is None
assert dm.feature_types is None