Validate features for inplace predict. (#8359)

This commit is contained in:
Jiaming Yuan
2022-10-19 23:05:36 +08:00
committed by GitHub
parent 52977f0cdf
commit c884b9e888
2 changed files with 78 additions and 52 deletions

View File

@@ -406,10 +406,7 @@ def c_array(
def _prediction_output(
shape: CNumericPtr,
dims: c_bst_ulong,
predts: CFloatPtr,
is_cuda: bool
shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool
) -> NumpyOrCupy:
arr_shape = ctypes2numpy(shape, dims.value, np.uint64)
length = int(np.prod(arr_shape))
@@ -1555,7 +1552,7 @@ class Booster:
ctypes.byref(self.handle)))
for d in cache:
# Validate feature only after the feature names are saved into booster.
self._validate_features(d)
self._validate_dmatrix_features(d)
if isinstance(model_file, Booster):
assert self.handle is not None
@@ -1914,7 +1911,7 @@ class Booster:
"""
if not isinstance(dtrain, DMatrix):
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
self._validate_features(dtrain)
self._validate_dmatrix_features(dtrain)
if fobj is None:
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle,
@@ -1946,7 +1943,7 @@ class Booster:
)
if not isinstance(dtrain, DMatrix):
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
self._validate_features(dtrain)
self._validate_dmatrix_features(dtrain)
_check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
c_array(ctypes.c_float, grad),
@@ -1982,7 +1979,7 @@ class Booster:
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
if not isinstance(d[1], str):
raise TypeError(f"expected string, got {type(d[1]).__name__}")
self._validate_features(d[0])
self._validate_dmatrix_features(d[0])
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
@@ -2033,7 +2030,7 @@ class Booster:
result: str
Evaluation result string.
"""
self._validate_features(data)
self._validate_dmatrix_features(data)
return self.eval_set([(data, name)], iteration)
# pylint: disable=too-many-function-args
@@ -2136,7 +2133,7 @@ class Booster:
if not isinstance(data, DMatrix):
raise TypeError('Expecting data to be a DMatrix object, got: ', type(data))
if validate_features:
self._validate_features(data)
self._validate_dmatrix_features(data)
iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range)
args = {
"type": 0,
@@ -2184,8 +2181,8 @@ class Booster:
base_margin: Any = None,
strict_shape: bool = False
) -> NumpyOrCupy:
"""Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction does not
cache the prediction result.
"""Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction
does not cache the prediction result.
Calling only ``inplace_predict`` in multiple threads is safe and lock
free. But the safety does not hold when used in conjunction with other
@@ -2273,18 +2270,22 @@ class Booster:
)
from .data import (
_is_pandas_df,
_transform_pandas_df,
_array_interface,
_is_cudf_df,
_is_cupy_array,
_array_interface,
_is_pandas_df,
_transform_pandas_df,
)
enable_categorical = _has_categorical(self, data)
if _is_pandas_df(data):
data, _, _ = _transform_pandas_df(data, enable_categorical)
data, fns, _ = _transform_pandas_df(data, enable_categorical)
if validate_features:
self._validate_features(fns)
if isinstance(data, np.ndarray):
from .data import _ensure_np_dtype
data, _ = _ensure_np_dtype(data, data.dtype)
_check_call(
_LIB.XGBoosterPredictFromDense(
@@ -2334,10 +2335,13 @@ class Booster:
return _prediction_output(shape, dims, preds, True)
if _is_cudf_df(data):
from .data import _cudf_array_interfaces, _transform_cudf_df
data, cat_codes, _, _ = _transform_cudf_df(
data, cat_codes, fns, _ = _transform_cudf_df(
data, None, None, enable_categorical
)
interfaces_str = _cudf_array_interfaces(data, cat_codes)
if validate_features:
self._validate_features(fns)
_check_call(
_LIB.XGBoosterPredictFromCudaColumnar(
self.handle,
@@ -2723,40 +2727,55 @@ class Booster:
# pylint: disable=no-member
return df.sort(['Tree', 'Node']).reset_index(drop=True)
def _validate_features(self, data: DMatrix) -> None:
"""
Validate Booster and data's feature_names are identical.
Set feature_names and feature_types from DMatrix
"""
def _validate_dmatrix_features(self, data: DMatrix) -> None:
if data.num_row() == 0:
return
fn = data.feature_names
ft = data.feature_types
# Be consistent with versions before 1.7, "validate" actually modifies the
# booster.
if self.feature_names is None:
self.feature_names = data.feature_names
self.feature_types = data.feature_types
if data.feature_names is None and self.feature_names is not None:
raise ValueError(
"training data did not have the following fields: " +
", ".join(self.feature_names)
)
# Booster can't accept data with different feature names
if self.feature_names != data.feature_names:
dat_missing = set(cast(FeatureNames, self.feature_names)) - \
set(cast(FeatureNames, data.feature_names))
my_missing = set(cast(FeatureNames, data.feature_names)) - \
set(cast(FeatureNames, self.feature_names))
self.feature_names = fn
if self.feature_types is None:
self.feature_types = ft
msg = 'feature_names mismatch: {0} {1}'
self._validate_features(fn)
def _validate_features(self, feature_names: Optional[FeatureNames]) -> None:
if self.feature_names is None:
return
if feature_names is None and self.feature_names is not None:
raise ValueError(
"training data did not have the following fields: "
+ ", ".join(self.feature_names)
)
if self.feature_names != feature_names:
dat_missing = set(cast(FeatureNames, self.feature_names)) - set(
cast(FeatureNames, feature_names)
)
my_missing = set(cast(FeatureNames, feature_names)) - set(
cast(FeatureNames, self.feature_names)
)
msg = "feature_names mismatch: {0} {1}"
if dat_missing:
msg += ('\nexpected ' + ', '.join(
str(s) for s in dat_missing) + ' in input data')
msg += (
"\nexpected "
+ ", ".join(str(s) for s in dat_missing)
+ " in input data"
)
if my_missing:
msg += ('\ntraining data did not have the following fields: ' +
', '.join(str(s) for s in my_missing))
msg += (
"\ntraining data did not have the following fields: "
+ ", ".join(str(s) for s in my_missing)
)
raise ValueError(msg.format(self.feature_names, data.feature_names))
raise ValueError(msg.format(self.feature_names, feature_names))
def get_split_value_histogram(
self,