Make feature validation immutable. (#9388)

This commit is contained in:
Jiaming Yuan 2023-07-16 06:52:55 +08:00 committed by GitHub
parent 0a07900b9f
commit b342ef951b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1623,7 +1623,7 @@ class Booster:
)
for d in cache:
# Validate feature only after the feature names are saved into booster.
self._validate_dmatrix_features(d)
self._assign_dmatrix_features(d)
if isinstance(model_file, Booster):
assert self.handle is not None
@ -1746,6 +1746,11 @@ class Booster:
self.__dict__.update(state)
def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
"""Get a slice of the tree-based model.
.. versionadded:: 1.3.0
"""
if isinstance(val, int):
val = slice(val, val + 1)
if isinstance(val, tuple):
@ -1784,6 +1789,11 @@ class Booster:
return sliced
def __iter__(self) -> Generator["Booster", None, None]:
"""Iterator method for getting individual trees.
.. versionadded:: 2.0.0
"""
for i in range(0, self.num_boosted_rounds()):
yield self[i]
@ -1994,7 +2004,7 @@ class Booster:
"""
if not isinstance(dtrain, DMatrix):
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
self._validate_dmatrix_features(dtrain)
self._assign_dmatrix_features(dtrain)
if fobj is None:
_check_call(
@ -2026,7 +2036,7 @@ class Booster:
raise ValueError(f"grad / hess length mismatch: {len(grad)} / {len(hess)}")
if not isinstance(dtrain, DMatrix):
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
self._validate_dmatrix_features(dtrain)
self._assign_dmatrix_features(dtrain)
_check_call(
_LIB.XGBoosterBoostOneIter(
@ -2067,7 +2077,7 @@ class Booster:
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
if not isinstance(d[1], str):
raise TypeError(f"expected string, got {type(d[1]).__name__}")
self._validate_dmatrix_features(d[0])
self._assign_dmatrix_features(d[0])
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
@ -2119,7 +2129,7 @@ class Booster:
result: str
Evaluation result string.
"""
self._validate_dmatrix_features(data)
self._assign_dmatrix_features(data)
return self.eval_set([(data, name)], iteration)
# pylint: disable=too-many-function-args
@ -2218,7 +2228,8 @@ class Booster:
if not isinstance(data, DMatrix):
raise TypeError("Expecting data to be a DMatrix object, got: ", type(data))
if validate_features:
self._validate_dmatrix_features(data)
fn = data.feature_names
self._validate_features(fn)
args = {
"type": 0,
"training": training,
@ -2843,14 +2854,13 @@ class Booster:
# pylint: disable=no-member
return df.sort(["Tree", "Node"]).reset_index(drop=True)
def _validate_dmatrix_features(self, data: DMatrix) -> None:
def _assign_dmatrix_features(self, data: DMatrix) -> None:
if data.num_row() == 0:
return
fn = data.feature_names
ft = data.feature_types
# Be consistent with versions before 1.7, "validate" actually modifies the
# booster.
if self.feature_names is None:
self.feature_names = fn
if self.feature_types is None: