Make feature validation immutable. (#9388)

This commit is contained in:
Jiaming Yuan 2023-07-16 06:52:55 +08:00 committed by GitHub
parent 0a07900b9f
commit b342ef951b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1623,7 +1623,7 @@ class Booster:
) )
for d in cache: for d in cache:
# Validate feature only after the feature names are saved into booster. # Validate feature only after the feature names are saved into booster.
self._validate_dmatrix_features(d) self._assign_dmatrix_features(d)
if isinstance(model_file, Booster): if isinstance(model_file, Booster):
assert self.handle is not None assert self.handle is not None
@ -1746,6 +1746,11 @@ class Booster:
self.__dict__.update(state) self.__dict__.update(state)
def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster": def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
"""Get a slice of the tree-based model.
.. versionadded:: 1.3.0
"""
if isinstance(val, int): if isinstance(val, int):
val = slice(val, val + 1) val = slice(val, val + 1)
if isinstance(val, tuple): if isinstance(val, tuple):
@ -1784,6 +1789,11 @@ class Booster:
return sliced return sliced
def __iter__(self) -> Generator["Booster", None, None]: def __iter__(self) -> Generator["Booster", None, None]:
"""Iterator method for getting individual trees.
.. versionadded:: 2.0.0
"""
for i in range(0, self.num_boosted_rounds()): for i in range(0, self.num_boosted_rounds()):
yield self[i] yield self[i]
@ -1994,7 +2004,7 @@ class Booster:
""" """
if not isinstance(dtrain, DMatrix): if not isinstance(dtrain, DMatrix):
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}") raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
self._validate_dmatrix_features(dtrain) self._assign_dmatrix_features(dtrain)
if fobj is None: if fobj is None:
_check_call( _check_call(
@ -2026,7 +2036,7 @@ class Booster:
raise ValueError(f"grad / hess length mismatch: {len(grad)} / {len(hess)}") raise ValueError(f"grad / hess length mismatch: {len(grad)} / {len(hess)}")
if not isinstance(dtrain, DMatrix): if not isinstance(dtrain, DMatrix):
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}") raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
self._validate_dmatrix_features(dtrain) self._assign_dmatrix_features(dtrain)
_check_call( _check_call(
_LIB.XGBoosterBoostOneIter( _LIB.XGBoosterBoostOneIter(
@ -2067,7 +2077,7 @@ class Booster:
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}") raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
if not isinstance(d[1], str): if not isinstance(d[1], str):
raise TypeError(f"expected string, got {type(d[1]).__name__}") raise TypeError(f"expected string, got {type(d[1]).__name__}")
self._validate_dmatrix_features(d[0]) self._assign_dmatrix_features(d[0])
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals]) dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals]) evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
@ -2119,7 +2129,7 @@ class Booster:
result: str result: str
Evaluation result string. Evaluation result string.
""" """
self._validate_dmatrix_features(data) self._assign_dmatrix_features(data)
return self.eval_set([(data, name)], iteration) return self.eval_set([(data, name)], iteration)
# pylint: disable=too-many-function-args # pylint: disable=too-many-function-args
@ -2218,7 +2228,8 @@ class Booster:
if not isinstance(data, DMatrix): if not isinstance(data, DMatrix):
raise TypeError("Expecting data to be a DMatrix object, got: ", type(data)) raise TypeError("Expecting data to be a DMatrix object, got: ", type(data))
if validate_features: if validate_features:
self._validate_dmatrix_features(data) fn = data.feature_names
self._validate_features(fn)
args = { args = {
"type": 0, "type": 0,
"training": training, "training": training,
@ -2843,14 +2854,13 @@ class Booster:
# pylint: disable=no-member # pylint: disable=no-member
return df.sort(["Tree", "Node"]).reset_index(drop=True) return df.sort(["Tree", "Node"]).reset_index(drop=True)
def _validate_dmatrix_features(self, data: DMatrix) -> None: def _assign_dmatrix_features(self, data: DMatrix) -> None:
if data.num_row() == 0: if data.num_row() == 0:
return return
fn = data.feature_names fn = data.feature_names
ft = data.feature_types ft = data.feature_types
# Be consistent with versions before 1.7, "validate" actually modifies the
# booster.
if self.feature_names is None: if self.feature_names is None:
self.feature_names = fn self.feature_names = fn
if self.feature_types is None: if self.feature_types is None: