Make feature validation immutable. (#9388)
This commit is contained in:
parent
0a07900b9f
commit
b342ef951b
@ -1623,7 +1623,7 @@ class Booster:
|
|||||||
)
|
)
|
||||||
for d in cache:
|
for d in cache:
|
||||||
# Validate feature only after the feature names are saved into booster.
|
# Validate feature only after the feature names are saved into booster.
|
||||||
self._validate_dmatrix_features(d)
|
self._assign_dmatrix_features(d)
|
||||||
|
|
||||||
if isinstance(model_file, Booster):
|
if isinstance(model_file, Booster):
|
||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
@ -1746,6 +1746,11 @@ class Booster:
|
|||||||
self.__dict__.update(state)
|
self.__dict__.update(state)
|
||||||
|
|
||||||
def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
|
def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
|
||||||
|
"""Get a slice of the tree-based model.
|
||||||
|
|
||||||
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
|
"""
|
||||||
if isinstance(val, int):
|
if isinstance(val, int):
|
||||||
val = slice(val, val + 1)
|
val = slice(val, val + 1)
|
||||||
if isinstance(val, tuple):
|
if isinstance(val, tuple):
|
||||||
@ -1784,6 +1789,11 @@ class Booster:
|
|||||||
return sliced
|
return sliced
|
||||||
|
|
||||||
def __iter__(self) -> Generator["Booster", None, None]:
|
def __iter__(self) -> Generator["Booster", None, None]:
|
||||||
|
"""Iterator method for getting individual trees.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0.0
|
||||||
|
|
||||||
|
"""
|
||||||
for i in range(0, self.num_boosted_rounds()):
|
for i in range(0, self.num_boosted_rounds()):
|
||||||
yield self[i]
|
yield self[i]
|
||||||
|
|
||||||
@ -1994,7 +2004,7 @@ class Booster:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(dtrain, DMatrix):
|
if not isinstance(dtrain, DMatrix):
|
||||||
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
||||||
self._validate_dmatrix_features(dtrain)
|
self._assign_dmatrix_features(dtrain)
|
||||||
|
|
||||||
if fobj is None:
|
if fobj is None:
|
||||||
_check_call(
|
_check_call(
|
||||||
@ -2026,7 +2036,7 @@ class Booster:
|
|||||||
raise ValueError(f"grad / hess length mismatch: {len(grad)} / {len(hess)}")
|
raise ValueError(f"grad / hess length mismatch: {len(grad)} / {len(hess)}")
|
||||||
if not isinstance(dtrain, DMatrix):
|
if not isinstance(dtrain, DMatrix):
|
||||||
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
||||||
self._validate_dmatrix_features(dtrain)
|
self._assign_dmatrix_features(dtrain)
|
||||||
|
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGBoosterBoostOneIter(
|
_LIB.XGBoosterBoostOneIter(
|
||||||
@ -2067,7 +2077,7 @@ class Booster:
|
|||||||
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
|
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
|
||||||
if not isinstance(d[1], str):
|
if not isinstance(d[1], str):
|
||||||
raise TypeError(f"expected string, got {type(d[1]).__name__}")
|
raise TypeError(f"expected string, got {type(d[1]).__name__}")
|
||||||
self._validate_dmatrix_features(d[0])
|
self._assign_dmatrix_features(d[0])
|
||||||
|
|
||||||
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
||||||
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
||||||
@ -2119,7 +2129,7 @@ class Booster:
|
|||||||
result: str
|
result: str
|
||||||
Evaluation result string.
|
Evaluation result string.
|
||||||
"""
|
"""
|
||||||
self._validate_dmatrix_features(data)
|
self._assign_dmatrix_features(data)
|
||||||
return self.eval_set([(data, name)], iteration)
|
return self.eval_set([(data, name)], iteration)
|
||||||
|
|
||||||
# pylint: disable=too-many-function-args
|
# pylint: disable=too-many-function-args
|
||||||
@ -2218,7 +2228,8 @@ class Booster:
|
|||||||
if not isinstance(data, DMatrix):
|
if not isinstance(data, DMatrix):
|
||||||
raise TypeError("Expecting data to be a DMatrix object, got: ", type(data))
|
raise TypeError("Expecting data to be a DMatrix object, got: ", type(data))
|
||||||
if validate_features:
|
if validate_features:
|
||||||
self._validate_dmatrix_features(data)
|
fn = data.feature_names
|
||||||
|
self._validate_features(fn)
|
||||||
args = {
|
args = {
|
||||||
"type": 0,
|
"type": 0,
|
||||||
"training": training,
|
"training": training,
|
||||||
@ -2843,14 +2854,13 @@ class Booster:
|
|||||||
# pylint: disable=no-member
|
# pylint: disable=no-member
|
||||||
return df.sort(["Tree", "Node"]).reset_index(drop=True)
|
return df.sort(["Tree", "Node"]).reset_index(drop=True)
|
||||||
|
|
||||||
def _validate_dmatrix_features(self, data: DMatrix) -> None:
|
def _assign_dmatrix_features(self, data: DMatrix) -> None:
|
||||||
if data.num_row() == 0:
|
if data.num_row() == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
fn = data.feature_names
|
fn = data.feature_names
|
||||||
ft = data.feature_types
|
ft = data.feature_types
|
||||||
# Be consistent with versions before 1.7, "validate" actually modifies the
|
|
||||||
# booster.
|
|
||||||
if self.feature_names is None:
|
if self.feature_names is None:
|
||||||
self.feature_names = fn
|
self.feature_names = fn
|
||||||
if self.feature_types is None:
|
if self.feature_types is None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user