Validate features for inplace predict. (#8359)
This commit is contained in:
parent
52977f0cdf
commit
c884b9e888
@ -406,10 +406,7 @@ def c_array(
|
|||||||
|
|
||||||
|
|
||||||
def _prediction_output(
|
def _prediction_output(
|
||||||
shape: CNumericPtr,
|
shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool
|
||||||
dims: c_bst_ulong,
|
|
||||||
predts: CFloatPtr,
|
|
||||||
is_cuda: bool
|
|
||||||
) -> NumpyOrCupy:
|
) -> NumpyOrCupy:
|
||||||
arr_shape = ctypes2numpy(shape, dims.value, np.uint64)
|
arr_shape = ctypes2numpy(shape, dims.value, np.uint64)
|
||||||
length = int(np.prod(arr_shape))
|
length = int(np.prod(arr_shape))
|
||||||
@ -1555,7 +1552,7 @@ class Booster:
|
|||||||
ctypes.byref(self.handle)))
|
ctypes.byref(self.handle)))
|
||||||
for d in cache:
|
for d in cache:
|
||||||
# Validate feature only after the feature names are saved into booster.
|
# Validate feature only after the feature names are saved into booster.
|
||||||
self._validate_features(d)
|
self._validate_dmatrix_features(d)
|
||||||
|
|
||||||
if isinstance(model_file, Booster):
|
if isinstance(model_file, Booster):
|
||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
@ -1914,7 +1911,7 @@ class Booster:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(dtrain, DMatrix):
|
if not isinstance(dtrain, DMatrix):
|
||||||
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
||||||
self._validate_features(dtrain)
|
self._validate_dmatrix_features(dtrain)
|
||||||
|
|
||||||
if fobj is None:
|
if fobj is None:
|
||||||
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle,
|
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle,
|
||||||
@ -1946,7 +1943,7 @@ class Booster:
|
|||||||
)
|
)
|
||||||
if not isinstance(dtrain, DMatrix):
|
if not isinstance(dtrain, DMatrix):
|
||||||
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
raise TypeError(f"invalid training matrix: {type(dtrain).__name__}")
|
||||||
self._validate_features(dtrain)
|
self._validate_dmatrix_features(dtrain)
|
||||||
|
|
||||||
_check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
|
_check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
|
||||||
c_array(ctypes.c_float, grad),
|
c_array(ctypes.c_float, grad),
|
||||||
@ -1982,7 +1979,7 @@ class Booster:
|
|||||||
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
|
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
|
||||||
if not isinstance(d[1], str):
|
if not isinstance(d[1], str):
|
||||||
raise TypeError(f"expected string, got {type(d[1]).__name__}")
|
raise TypeError(f"expected string, got {type(d[1]).__name__}")
|
||||||
self._validate_features(d[0])
|
self._validate_dmatrix_features(d[0])
|
||||||
|
|
||||||
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
||||||
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
||||||
@ -2033,7 +2030,7 @@ class Booster:
|
|||||||
result: str
|
result: str
|
||||||
Evaluation result string.
|
Evaluation result string.
|
||||||
"""
|
"""
|
||||||
self._validate_features(data)
|
self._validate_dmatrix_features(data)
|
||||||
return self.eval_set([(data, name)], iteration)
|
return self.eval_set([(data, name)], iteration)
|
||||||
|
|
||||||
# pylint: disable=too-many-function-args
|
# pylint: disable=too-many-function-args
|
||||||
@ -2136,7 +2133,7 @@ class Booster:
|
|||||||
if not isinstance(data, DMatrix):
|
if not isinstance(data, DMatrix):
|
||||||
raise TypeError('Expecting data to be a DMatrix object, got: ', type(data))
|
raise TypeError('Expecting data to be a DMatrix object, got: ', type(data))
|
||||||
if validate_features:
|
if validate_features:
|
||||||
self._validate_features(data)
|
self._validate_dmatrix_features(data)
|
||||||
iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range)
|
iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range)
|
||||||
args = {
|
args = {
|
||||||
"type": 0,
|
"type": 0,
|
||||||
@ -2184,8 +2181,8 @@ class Booster:
|
|||||||
base_margin: Any = None,
|
base_margin: Any = None,
|
||||||
strict_shape: bool = False
|
strict_shape: bool = False
|
||||||
) -> NumpyOrCupy:
|
) -> NumpyOrCupy:
|
||||||
"""Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction does not
|
"""Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction
|
||||||
cache the prediction result.
|
does not cache the prediction result.
|
||||||
|
|
||||||
Calling only ``inplace_predict`` in multiple threads is safe and lock
|
Calling only ``inplace_predict`` in multiple threads is safe and lock
|
||||||
free. But the safety does not hold when used in conjunction with other
|
free. But the safety does not hold when used in conjunction with other
|
||||||
@ -2273,18 +2270,22 @@ class Booster:
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .data import (
|
from .data import (
|
||||||
_is_pandas_df,
|
_array_interface,
|
||||||
_transform_pandas_df,
|
|
||||||
_is_cudf_df,
|
_is_cudf_df,
|
||||||
_is_cupy_array,
|
_is_cupy_array,
|
||||||
_array_interface,
|
_is_pandas_df,
|
||||||
|
_transform_pandas_df,
|
||||||
)
|
)
|
||||||
|
|
||||||
enable_categorical = _has_categorical(self, data)
|
enable_categorical = _has_categorical(self, data)
|
||||||
if _is_pandas_df(data):
|
if _is_pandas_df(data):
|
||||||
data, _, _ = _transform_pandas_df(data, enable_categorical)
|
data, fns, _ = _transform_pandas_df(data, enable_categorical)
|
||||||
|
if validate_features:
|
||||||
|
self._validate_features(fns)
|
||||||
|
|
||||||
if isinstance(data, np.ndarray):
|
if isinstance(data, np.ndarray):
|
||||||
from .data import _ensure_np_dtype
|
from .data import _ensure_np_dtype
|
||||||
|
|
||||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGBoosterPredictFromDense(
|
_LIB.XGBoosterPredictFromDense(
|
||||||
@ -2334,10 +2335,13 @@ class Booster:
|
|||||||
return _prediction_output(shape, dims, preds, True)
|
return _prediction_output(shape, dims, preds, True)
|
||||||
if _is_cudf_df(data):
|
if _is_cudf_df(data):
|
||||||
from .data import _cudf_array_interfaces, _transform_cudf_df
|
from .data import _cudf_array_interfaces, _transform_cudf_df
|
||||||
data, cat_codes, _, _ = _transform_cudf_df(
|
|
||||||
|
data, cat_codes, fns, _ = _transform_cudf_df(
|
||||||
data, None, None, enable_categorical
|
data, None, None, enable_categorical
|
||||||
)
|
)
|
||||||
interfaces_str = _cudf_array_interfaces(data, cat_codes)
|
interfaces_str = _cudf_array_interfaces(data, cat_codes)
|
||||||
|
if validate_features:
|
||||||
|
self._validate_features(fns)
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGBoosterPredictFromCudaColumnar(
|
_LIB.XGBoosterPredictFromCudaColumnar(
|
||||||
self.handle,
|
self.handle,
|
||||||
@ -2723,40 +2727,55 @@ class Booster:
|
|||||||
# pylint: disable=no-member
|
# pylint: disable=no-member
|
||||||
return df.sort(['Tree', 'Node']).reset_index(drop=True)
|
return df.sort(['Tree', 'Node']).reset_index(drop=True)
|
||||||
|
|
||||||
def _validate_features(self, data: DMatrix) -> None:
|
def _validate_dmatrix_features(self, data: DMatrix) -> None:
|
||||||
"""
|
|
||||||
Validate Booster and data's feature_names are identical.
|
|
||||||
Set feature_names and feature_types from DMatrix
|
|
||||||
"""
|
|
||||||
if data.num_row() == 0:
|
if data.num_row() == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
fn = data.feature_names
|
||||||
|
ft = data.feature_types
|
||||||
|
# Be consistent with versions before 1.7, "validate" actually modifies the
|
||||||
|
# booster.
|
||||||
if self.feature_names is None:
|
if self.feature_names is None:
|
||||||
self.feature_names = data.feature_names
|
self.feature_names = fn
|
||||||
self.feature_types = data.feature_types
|
if self.feature_types is None:
|
||||||
if data.feature_names is None and self.feature_names is not None:
|
self.feature_types = ft
|
||||||
raise ValueError(
|
|
||||||
"training data did not have the following fields: " +
|
|
||||||
", ".join(self.feature_names)
|
|
||||||
)
|
|
||||||
# Booster can't accept data with different feature names
|
|
||||||
if self.feature_names != data.feature_names:
|
|
||||||
dat_missing = set(cast(FeatureNames, self.feature_names)) - \
|
|
||||||
set(cast(FeatureNames, data.feature_names))
|
|
||||||
my_missing = set(cast(FeatureNames, data.feature_names)) - \
|
|
||||||
set(cast(FeatureNames, self.feature_names))
|
|
||||||
|
|
||||||
msg = 'feature_names mismatch: {0} {1}'
|
self._validate_features(fn)
|
||||||
|
|
||||||
|
def _validate_features(self, feature_names: Optional[FeatureNames]) -> None:
|
||||||
|
if self.feature_names is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if feature_names is None and self.feature_names is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"training data did not have the following fields: "
|
||||||
|
+ ", ".join(self.feature_names)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.feature_names != feature_names:
|
||||||
|
dat_missing = set(cast(FeatureNames, self.feature_names)) - set(
|
||||||
|
cast(FeatureNames, feature_names)
|
||||||
|
)
|
||||||
|
my_missing = set(cast(FeatureNames, feature_names)) - set(
|
||||||
|
cast(FeatureNames, self.feature_names)
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = "feature_names mismatch: {0} {1}"
|
||||||
|
|
||||||
if dat_missing:
|
if dat_missing:
|
||||||
msg += ('\nexpected ' + ', '.join(
|
msg += (
|
||||||
str(s) for s in dat_missing) + ' in input data')
|
"\nexpected "
|
||||||
|
+ ", ".join(str(s) for s in dat_missing)
|
||||||
|
+ " in input data"
|
||||||
|
)
|
||||||
|
|
||||||
if my_missing:
|
if my_missing:
|
||||||
msg += ('\ntraining data did not have the following fields: ' +
|
msg += (
|
||||||
', '.join(str(s) for s in my_missing))
|
"\ntraining data did not have the following fields: "
|
||||||
|
+ ", ".join(str(s) for s in my_missing)
|
||||||
|
)
|
||||||
|
|
||||||
raise ValueError(msg.format(self.feature_names, data.feature_names))
|
raise ValueError(msg.format(self.feature_names, feature_names))
|
||||||
|
|
||||||
def get_split_value_histogram(
|
def get_split_value_histogram(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import collections
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
@ -998,34 +999,40 @@ def test_deprecate_position_arg():
|
|||||||
def test_pandas_input():
|
def test_pandas_input():
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn.calibration import CalibratedClassifierCV
|
from sklearn.calibration import CalibratedClassifierCV
|
||||||
|
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
kRows = 100
|
kRows = 100
|
||||||
kCols = 6
|
kCols = 6
|
||||||
|
|
||||||
X = rng.randint(low=0, high=2, size=kRows*kCols)
|
X = rng.randint(low=0, high=2, size=kRows * kCols)
|
||||||
X = X.reshape(kRows, kCols)
|
X = X.reshape(kRows, kCols)
|
||||||
|
|
||||||
df = pd.DataFrame(X)
|
df = pd.DataFrame(X)
|
||||||
feature_names = []
|
feature_names = []
|
||||||
for i in range(1, kCols):
|
for i in range(1, kCols):
|
||||||
feature_names += ['k'+str(i)]
|
feature_names += ["k" + str(i)]
|
||||||
|
|
||||||
df.columns = ['status'] + feature_names
|
df.columns = ["status"] + feature_names
|
||||||
|
|
||||||
target = df['status']
|
target = df["status"]
|
||||||
train = df.drop(columns=['status'])
|
train = df.drop(columns=["status"])
|
||||||
model = xgb.XGBClassifier()
|
model = xgb.XGBClassifier()
|
||||||
model.fit(train, target)
|
model.fit(train, target)
|
||||||
np.testing.assert_equal(model.feature_names_in_, np.array(feature_names))
|
np.testing.assert_equal(model.feature_names_in_, np.array(feature_names))
|
||||||
|
|
||||||
clf_isotonic = CalibratedClassifierCV(model,
|
columns = list(train.columns)
|
||||||
cv='prefit', method='isotonic')
|
random.shuffle(columns, lambda: 0.1)
|
||||||
|
df_incorrect = df[columns]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model.predict(df_incorrect)
|
||||||
|
|
||||||
|
clf_isotonic = CalibratedClassifierCV(model, cv="prefit", method="isotonic")
|
||||||
clf_isotonic.fit(train, target)
|
clf_isotonic.fit(train, target)
|
||||||
assert isinstance(clf_isotonic.calibrated_classifiers_[0].base_estimator,
|
assert isinstance(
|
||||||
xgb.XGBClassifier)
|
clf_isotonic.calibrated_classifiers_[0].base_estimator, xgb.XGBClassifier
|
||||||
np.testing.assert_allclose(np.array(clf_isotonic.classes_),
|
)
|
||||||
np.array([0, 1]))
|
np.testing.assert_allclose(np.array(clf_isotonic.classes_), np.array([0, 1]))
|
||||||
|
|
||||||
|
|
||||||
def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
|
def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user