Make feature validation immutable. (#9388)
This commit is contained in:
parent
0a07900b9f
commit
b342ef951b
@ -1623,7 +1623,7 @@ class Booster:
|
||||
)
|
||||
for d in cache:
|
||||
# Validate feature only after the feature names are saved into booster.
|
||||
self._validate_dmatrix_features(d)
|
||||
self._assign_dmatrix_features(d)
|
||||
|
||||
if isinstance(model_file, Booster):
|
||||
assert self.handle is not None
|
||||
@ -1746,6 +1746,11 @@ class Booster:
|
||||
self.__dict__.update(state)
|
||||
|
||||
def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
|
||||
"""Get a slice of the tree-based model.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
"""
|
||||
if isinstance(val, int):
|
||||
val = slice(val, val + 1)
|
||||
if isinstance(val, tuple):
|
||||
@ -1784,6 +1789,11 @@ class Booster:
|
||||
return sliced
|
||||
|
||||
def __iter__(self) -> Generator["Booster", None, None]:
    """Yield the model piece by piece, one boosted round per step.

    The number of rounds is read once from ``num_boosted_rounds()``; the
    indices ``0 .. rounds - 1`` are then visited in order and ``self[index]``
    is yielded for each, i.e. whatever integer indexing of this object
    returns.

    .. versionadded:: 2.0.0

    """
    total_rounds = self.num_boosted_rounds()
    for round_index in range(total_rounds):
        yield self[round_index]
|
||||
|
||||
@ -1994,7 +2004,7 @@ class Booster:
|
||||
"""
|
||||
if not isinstance(dtrain, DMatrix):
|
||||
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
||||
self._validate_dmatrix_features(dtrain)
|
||||
self._assign_dmatrix_features(dtrain)
|
||||
|
||||
if fobj is None:
|
||||
_check_call(
|
||||
@ -2026,7 +2036,7 @@ class Booster:
|
||||
raise ValueError(f"grad / hess length mismatch: {len(grad)} / {len(hess)}")
|
||||
if not isinstance(dtrain, DMatrix):
|
||||
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
||||
self._validate_dmatrix_features(dtrain)
|
||||
self._assign_dmatrix_features(dtrain)
|
||||
|
||||
_check_call(
|
||||
_LIB.XGBoosterBoostOneIter(
|
||||
@ -2067,7 +2077,7 @@ class Booster:
|
||||
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
|
||||
if not isinstance(d[1], str):
|
||||
raise TypeError(f"expected string, got {type(d[1]).__name__}")
|
||||
self._validate_dmatrix_features(d[0])
|
||||
self._assign_dmatrix_features(d[0])
|
||||
|
||||
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
||||
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
||||
@ -2119,7 +2129,7 @@ class Booster:
|
||||
result: str
|
||||
Evaluation result string.
|
||||
"""
|
||||
self._validate_dmatrix_features(data)
|
||||
self._assign_dmatrix_features(data)
|
||||
return self.eval_set([(data, name)], iteration)
|
||||
|
||||
# pylint: disable=too-many-function-args
|
||||
@ -2218,7 +2228,8 @@ class Booster:
|
||||
if not isinstance(data, DMatrix):
|
||||
raise TypeError("Expecting data to be a DMatrix object, got: ", type(data))
|
||||
if validate_features:
|
||||
self._validate_dmatrix_features(data)
|
||||
fn = data.feature_names
|
||||
self._validate_features(fn)
|
||||
args = {
|
||||
"type": 0,
|
||||
"training": training,
|
||||
@ -2843,14 +2854,13 @@ class Booster:
|
||||
# pylint: disable=no-member
|
||||
return df.sort(["Tree", "Node"]).reset_index(drop=True)
|
||||
|
||||
def _validate_dmatrix_features(self, data: DMatrix) -> None:
|
||||
def _assign_dmatrix_features(self, data: DMatrix) -> None:
|
||||
if data.num_row() == 0:
|
||||
return
|
||||
|
||||
fn = data.feature_names
|
||||
ft = data.feature_types
|
||||
# Be consistent with versions before 1.7, "validate" actually modifies the
|
||||
# booster.
|
||||
|
||||
if self.feature_names is None:
|
||||
self.feature_names = fn
|
||||
if self.feature_types is None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user