Use config_context in sklearn interface. (#8141)

Jiaming Yuan 2022-08-09 14:48:54 +08:00 committed by GitHub
parent 03cc3b359c
commit 9ae547f994
8 changed files with 560 additions and 438 deletions
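The user-visible effect: the ``verbosity`` parameter of the scikit-learn estimators is now applied through ``xgb.config_context`` inside ``fit``/``predict``/``apply``, so the global configuration is restored once the call returns. A minimal sketch of the intended behavior (hypothetical data, not part of this commit):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(128, 4)
    y = np.random.rand(128)

    before = xgb.get_config()["verbosity"]
    # verbosity=0 silences logging only for the duration of fit()/predict()
    model = xgb.XGBRegressor(n_estimators=4, verbosity=0)
    model.fit(X, y)
    model.predict(X)
    assert xgb.get_config()["verbosity"] == before  # global config untouched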

View File

@ -1,3 +1,5 @@
.. _dask-examples:
XGBoost Dask Feature Walkthrough
================================

View File

@ -126,7 +126,7 @@ master_doc = 'index'
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
language = "en"
autoclass_content = 'both'

View File

@ -115,7 +115,7 @@ Alternatively, XGBoost also implements the Scikit-Learn interface with
:py:class:`~xgboost.dask.DaskXGBRanker` and two random forest variants. This wrapper is
similar to the single-node Scikit-Learn interface in xgboost, with dask collections as
inputs and an additional ``client`` attribute. See the following sections and
:ref:`sphx_glr_python_dask-examples` for more examples.
:ref:`dask-examples` for more examples.
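For context, a hedged sketch of how that wrapper is typically used (assumes a local dask cluster; the data is illustrative):

.. code-block:: python

    from dask import array as da
    from dask.distributed import Client, LocalCluster

    import xgboost as xgb

    with Client(LocalCluster(n_workers=2)) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=(100,))
        reg = xgb.dask.DaskXGBRegressor(n_estimators=10)
        reg.client = client  # the additional ``client`` attribute
        reg.fit(X, y)
        preds = reg.predict(X)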
******************

View File

@ -16,6 +16,7 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
Distributed XGBoost with XGBoost4J-Spark-GPU <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_gpu_tutorial.html>
dask
spark_estimator
ray
dart
monotonic

View File

@ -70,6 +70,23 @@ def config_doc(
# Suppress warning caused by model generated with XGBoost version < 1.0.0
bst = xgb.Booster(model_file='./old_model.bin')
assert xgb.get_config()['verbosity'] == 2 # old value restored
Nested configuration context is also supported:
Example
-------
.. code-block:: python
with xgb.config_context(verbosity=3):
assert xgb.get_config()["verbosity"] == 3
with xgb.config_context(verbosity=2):
assert xgb.get_config()["verbosity"] == 2
xgb.set_config(verbosity=2)
assert xgb.get_config()["verbosity"] == 2
with xgb.config_context(verbosity=3):
assert xgb.get_config()["verbosity"] == 3
"""
def none_to_str(value: Optional[str]) -> str:
@ -98,7 +115,11 @@ def config_doc(
Keyword arguments representing the parameters and their values
""")
def set_config(**new_config: Any) -> None:
config = json.dumps(new_config)
not_none = {}
for k, v in new_config.items():
if v is not None:
not_none[k] = v
config = json.dumps(not_none)
_check_call(_LIB.XGBSetGlobalConfig(c_str(config)))
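With this change, ``None``-valued entries are dropped before the dictionary is serialized, so passing ``None`` leaves the corresponding global parameter untouched; the nested ``config_context`` behavior documented above relies on this. A short sketch of the effect:

.. code-block:: python

    import xgboost as xgb

    xgb.set_config(verbosity=2)
    xgb.set_config(verbosity=None)  # None entries are filtered out: a no-op
    assert xgb.get_config()["verbosity"] == 2

    with xgb.config_context(verbosity=None):
        assert xgb.get_config()["verbosity"] == 2  # unchanged inside the context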

View File

@ -1,43 +1,49 @@
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, too-many-lines
"""Scikit-Learn Wrapper interface for XGBoost."""
import copy
import warnings
import json
import os
import warnings
from typing import (
Union,
Optional,
List,
Dict,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Any,
TypeVar,
Type,
TypeVar,
Union,
cast,
)
import numpy as np
from scipy.special import softmax
from .core import Booster, DMatrix, XGBoostError
from .core import _deprecate_positional_args, _convert_ntree_limit
from .core import Metric
from .training import train
from .callback import TrainingCallback
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
from ._typing import ArrayLike, FeatureNames, FeatureTypes
from .callback import TrainingCallback
# Do not use class names on scikit-learn directly. Re-define the classes on
# .compat to guarantee the behavior without scikit-learn
from .compat import (
SKLEARN_INSTALLED,
XGBModelBase,
XGBClassifierBase,
XGBRegressorBase,
XGBModelBase,
XGBoostLabelEncoder,
XGBRegressorBase,
)
from .config import config_context
from .core import (
Booster,
DMatrix,
Metric,
XGBoostError,
_convert_ntree_limit,
_deprecate_positional_args,
)
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
from .training import train
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
@ -59,9 +65,7 @@ def _check_rf_callback(
_SklObjective = Optional[
Union[
str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
]
Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]
]
@ -95,10 +99,12 @@ def _objective_decorator(
The training set from which the labels will be extracted using
``dmatrix.get_label()``
"""
def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]:
"""internal function"""
labels = dmatrix.get_label()
return func(labels, preds)
return inner
@ -109,19 +115,21 @@ def _metric_decorator(func: Callable) -> Metric:
is compatible with :py:func:`train`
"""
def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
y_true = dmatrix.get_label()
return func.__name__, func(y_true, y_score)
return inner
__estimator_doc = '''
__estimator_doc = """
n_estimators : int
Number of gradient boosted trees. Equivalent to number of boosting
rounds.
'''
"""
__model_doc = f'''
__model_doc = f"""
max_depth : Optional[int]
Maximum tree depth for base learners.
max_leaves :
@ -146,10 +154,10 @@ __model_doc = f'''
recommended to study this option from the parameters document :doc:`tree method
</treemethod>`
n_jobs : Optional[int]
Number of parallel threads used to run xgboost. When used with other Scikit-Learn
algorithms like grid search, you may choose which algorithm to parallelize and
balance the threads. Creating thread contention will significantly slow down both
algorithms.
Number of parallel threads used to run xgboost. When used with other
Scikit-Learn algorithms like grid search, you may choose which algorithm to
parallelize and balance the threads. Creating thread contention will
significantly slow down both algorithms.
gamma : Optional[float]
(min_split_loss) Minimum loss reduction required to make a further partition on a
leaf node of the tree.
@ -332,9 +340,9 @@ __model_doc = f'''
\\*\\*kwargs is unsupported by scikit-learn. We do not guarantee
that parameters passed via this argument will interact properly
with scikit-learn.
'''
"""
__custom_obj_note = '''
__custom_obj_note = """
.. note:: Custom objective function
A custom objective function can be provided for the ``objective``
@ -350,15 +358,16 @@ __custom_obj_note = '''
The value of the gradient for each sample point.
hess: array_like of shape [n_samples]
The value of the second derivative for each sample point
'''
"""
def xgboost_model_doc(
header: str, items: List[str],
header: str,
items: List[str],
extra_parameters: Optional[str] = None,
end_note: Optional[str] = None
end_note: Optional[str] = None,
) -> Callable[[Type], Type]:
'''Obtain documentation for Scikit-Learn wrappers
"""Obtain documentation for Scikit-Learn wrappers
Parameters
----------
@ -372,29 +381,34 @@ def xgboost_model_doc(
extra_parameters: str
Document for class specific parameters, placed at the head.
end_note: str
Extra notes put to the end.
'''
Extra notes put to the end."""
def get_doc(item: str) -> str:
'''Return selected item'''
__doc = {'estimators': __estimator_doc,
'model': __model_doc,
'objective': __custom_obj_note}
"""Return selected item"""
__doc = {
"estimators": __estimator_doc,
"model": __model_doc,
"objective": __custom_obj_note,
}
return __doc[item]
def adddoc(cls: Type) -> Type:
doc = ['''
doc = [
"""
Parameters
----------
''']
"""
]
if extra_parameters:
doc.append(extra_parameters)
doc.extend([get_doc(i) for i in items])
if end_note:
doc.append(end_note)
full_doc = [header + '\n\n']
full_doc = [header + "\n\n"]
full_doc.extend(doc)
cls.__doc__ = ''.join(full_doc)
cls.__doc__ = "".join(full_doc)
return cls
return adddoc
@ -416,9 +430,7 @@ def _wrap_evaluation_matrices(
enable_categorical: bool,
feature_types: Optional[FeatureTypes],
) -> Tuple[Any, List[Tuple[Any, str]]]:
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
"""
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way."""
train_dmatrix = create_dmatrix(
data=X,
label=y,
@ -439,8 +451,8 @@ def _wrap_evaluation_matrices(
return [None] * n_validation
if len(meta) != n_validation:
raise ValueError(
f"{name}'s length does not equal `eval_set`'s length, " +
f"expecting {n_validation}, got {len(meta)}"
f"{name}'s length does not equal `eval_set`'s length, "
+ f"expecting {n_validation}, got {len(meta)}"
)
return meta
@ -459,11 +471,12 @@ def _wrap_evaluation_matrices(
# Skip the duplicated entry.
if all(
(
valid_X is X, valid_y is y,
valid_X is X,
valid_y is y,
sample_weight_eval_set[i] is sample_weight,
base_margin_eval_set[i] is base_margin,
eval_group[i] is group,
eval_qid[i] is qid
eval_qid[i] is qid,
)
):
evals.append(train_dmatrix)
@ -502,8 +515,10 @@ def _wrap_evaluation_matrices(
return train_dmatrix, evals
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
['estimators', 'model', 'objective'])
@xgboost_model_doc(
"""Implementation of the Scikit-Learn API for XGBoost.""",
["estimators", "model", "objective"],
)
class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, missing-docstring
def __init__(
@ -546,7 +561,7 @@ class XGBModel(XGBModelBase):
eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
callbacks: Optional[List[TrainingCallback]] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
if not SKLEARN_INSTALLED:
raise ImportError(
@ -595,8 +610,8 @@ class XGBModel(XGBModelBase):
self.kwargs = kwargs
def _more_tags(self) -> Dict[str, bool]:
'''Tags used for scikit-learn data validation.'''
return {'allow_nan': True, 'no_validation': True}
"""Tags used for scikit-learn data validation."""
return {"allow_nan": True, "no_validation": True}
def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")
@ -612,7 +627,8 @@ class XGBModel(XGBModelBase):
"""
if not self.__sklearn_is_fitted__():
from sklearn.exceptions import NotFittedError
raise NotFittedError('need to call fit or load_model beforehand')
raise NotFittedError("need to call fit or load_model beforehand")
return self._Booster
def set_params(self, **params: Any) -> "XGBModel":
@ -640,7 +656,7 @@ class XGBModel(XGBModelBase):
self.kwargs = {}
self.kwargs[key] = value
if hasattr(self, '_Booster'):
if hasattr(self, "_Booster"):
parameters = self.get_xgb_params()
self.get_booster().set_param(parameters)
@ -662,9 +678,10 @@ class XGBModel(XGBModelBase):
# if kwargs is a dict, update params accordingly
if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):
params.update(self.kwargs)
if isinstance(params['random_state'], np.random.RandomState):
params['random_state'] = params['random_state'].randint(
np.iinfo(np.int32).max)
if isinstance(params["random_state"], np.random.RandomState):
params["random_state"] = params["random_state"].randint(
np.iinfo(np.int32).max
)
def parse_parameter(value: Any) -> Optional[Union[int, float, str]]:
for t in (int, float, str):
@ -683,7 +700,7 @@ class XGBModel(XGBModelBase):
while stack:
obj = stack.pop()
for k, v in obj.items():
if k.endswith('_param'):
if k.endswith("_param"):
for p_k, p_v in v.items():
internal[p_k] = p_v
elif isinstance(v, dict):
@ -722,7 +739,7 @@ class XGBModel(XGBModelBase):
return self.n_estimators
def _get_type(self) -> str:
if not hasattr(self, '_estimator_type'):
if not hasattr(self, "_estimator_type"):
raise TypeError(
"`_estimator_type` undefined. "
"Please use appropriate mixin to define estimator type."
@ -732,14 +749,14 @@ class XGBModel(XGBModelBase):
def save_model(self, fname: Union[str, os.PathLike]) -> None:
meta: Dict[str, Any] = {}
for k, v in self.__dict__.items():
if k == '_le':
meta['_le'] = self._le.to_json()
if k == "_le":
meta["_le"] = self._le.to_json()
continue
if k == '_Booster':
if k == "_Booster":
continue
if k == 'classes_':
if k == "classes_":
# numpy array is not JSON serializable
meta['classes_'] = self.classes_.tolist()
meta["classes_"] = self.classes_.tolist()
continue
if k == "feature_types":
# Use the `feature_types` attribute from booster instead.
@ -749,8 +766,10 @@ class XGBModel(XGBModelBase):
json.dumps({k: v})
meta[k] = v
except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning)
meta['_estimator_type'] = self._get_type()
warnings.warn(
str(k) + " is not saved in Scikit-Learn meta.", UserWarning
)
meta["_estimator_type"] = self._get_type()
meta_str = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta_str)
self.get_booster().save_model(fname)
@ -761,27 +780,25 @@ class XGBModel(XGBModelBase):
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
# pylint: disable=attribute-defined-outside-init
if not hasattr(self, '_Booster'):
self._Booster = Booster({'n_jobs': self.n_jobs})
if not hasattr(self, "_Booster"):
self._Booster = Booster({"n_jobs": self.n_jobs})
self.get_booster().load_model(fname)
meta_str = self.get_booster().attr('scikit_learn')
meta_str = self.get_booster().attr("scikit_learn")
if meta_str is None:
# FIXME(jiaming): This doesn't have to be a problem as most of the needed
# information like num_class and objective is in Learner class.
warnings.warn(
'Loading a native XGBoost model with Scikit-Learn interface.'
)
warnings.warn("Loading a native XGBoost model with Scikit-Learn interface.")
return
meta = json.loads(meta_str)
states = {}
for k, v in meta.items():
if k == '_le':
if k == "_le":
self._le = XGBoostLabelEncoder()
self._le.from_json(v)
continue
# FIXME(jiaming): This can be removed once label encoder is gone since we can
# generate it from `np.arange(self.n_classes_)`
if k == 'classes_':
if k == "classes_":
self.classes_ = np.array(v)
continue
if k == "feature_types":
@ -907,7 +924,7 @@ class XGBModel(XGBModelBase):
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBModel":
# pylint: disable=invalid-name,attribute-defined-outside-init
"""Fit gradient boosting model.
@ -963,6 +980,7 @@ class XGBModel(XGBModelBase):
.. deprecated:: 1.6.0
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
"""
with config_context(verbosity=self.verbosity):
evals_result: TrainingCallback.EvalsLog = {}
train_dmatrix, evals = _wrap_evaluation_matrices(
missing=self.missing,
@ -980,7 +998,7 @@ class XGBModel(XGBModelBase):
eval_qid=None,
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
enable_categorical=self.enable_categorical,
feature_types=self.feature_types
feature_types=self.feature_types,
)
params = self.get_xgb_params()
@ -992,7 +1010,13 @@ class XGBModel(XGBModelBase):
else:
obj = None
model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
(
model,
metric,
params,
early_stopping_rounds,
callbacks,
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
self._Booster = train(
@ -1025,7 +1049,7 @@ class XGBModel(XGBModelBase):
def _get_iteration_range(
self, iteration_range: Optional[Tuple[int, int]]
) -> Tuple[int, int]:
if (iteration_range is None or iteration_range[1] == 0):
if iteration_range is None or iteration_range[1] == 0:
# Use best_iteration if defined.
try:
iteration_range = (0, self.best_iteration + 1)
@ -1067,8 +1091,8 @@ class XGBModel(XGBModelBase):
iteration_range :
Specifies which layer of trees are used in prediction. For example, if a
random forest is trained with 100 rounds. Specifying ``iteration_range=(10,
20)``, then only the forests built during [10, 20) (half open set) rounds are
used in this prediction.
20)``, then only the forests built during [10, 20) (half open set) rounds
are used in this prediction.
.. versionadded:: 1.4.0
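A short sketch of the ``iteration_range`` semantics described above (hypothetical data; assumes a fitted estimator):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(256, 8)
    y = np.random.rand(256)
    model = xgb.XGBRegressor(n_estimators=100).fit(X, y)

    # Use only the trees built during rounds [10, 20).
    preds = model.predict(X, iteration_range=(10, 20))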
@ -1077,6 +1101,7 @@ class XGBModel(XGBModelBase):
prediction
"""
with config_context(verbosity=self.verbosity):
iteration_range = _convert_ntree_limit(
self.get_booster(), ntree_limit, iteration_range
)
@ -1093,6 +1118,7 @@ class XGBModel(XGBModelBase):
)
if _is_cupy_array(predts):
import cupy # pylint: disable=import-error
predts = cupy.asnumpy(predts) # ensure numpy array is used.
return predts
except TypeError:
@ -1105,7 +1131,7 @@ class XGBModel(XGBModelBase):
missing=self.missing,
nthread=self.n_jobs,
feature_types=self.feature_types,
enable_categorical=self.enable_categorical
enable_categorical=self.enable_categorical,
)
return self.get_booster().predict(
data=test,
@ -1115,9 +1141,10 @@ class XGBModel(XGBModelBase):
)
def apply(
self, X: ArrayLike,
self,
X: ArrayLike,
ntree_limit: int = 0,
iteration_range: Optional[Tuple[int, int]] = None
iteration_range: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
"""Return the predicted leaf every tree for each sample. If the model is trained with
early stopping, then `best_iteration` is used automatically.
@ -1141,17 +1168,19 @@ class XGBModel(XGBModelBase):
``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering.
"""
with config_context(verbosity=self.verbosity):
iteration_range = _convert_ntree_limit(
self.get_booster(), ntree_limit, iteration_range
)
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = DMatrix(
X, missing=self.missing, feature_types=self.feature_types, nthread=self.n_jobs
X,
missing=self.missing,
feature_types=self.feature_types,
nthread=self.n_jobs,
)
return self.get_booster().predict(
test_dmatrix,
pred_leaf=True,
iteration_range=iteration_range
test_dmatrix, pred_leaf=True, iteration_range=iteration_range
)
def evals_result(self) -> Dict[str, Dict[str, List[float]]]:
@ -1208,13 +1237,13 @@ class XGBModel(XGBModelBase):
return getattr(booster, attr)
except AttributeError as e:
raise AttributeError(
f'`{attr}` is only defined when early stopping is used.'
f"`{attr}` is only defined when early stopping is used."
) from e
@property
def best_score(self) -> float:
"""The best score obtained by early stopping."""
return float(self._early_stopping_attr('best_score'))
return float(self._early_stopping_attr("best_score"))
@property
def best_iteration(self) -> int:
@ -1222,11 +1251,11 @@ class XGBModel(XGBModelBase):
for instance if the best iteration is the first round, then best_iteration is 0.
"""
return int(self._early_stopping_attr('best_iteration'))
return int(self._early_stopping_attr("best_iteration"))
@property
def best_ntree_limit(self) -> int:
return int(self._early_stopping_attr('best_ntree_limit'))
return int(self._early_stopping_attr("best_ntree_limit"))
@property
def feature_importances_(self) -> np.ndarray:
@ -1243,6 +1272,7 @@ class XGBModel(XGBModelBase):
def dft() -> str:
return "weight" if self.booster == "gblinear" else "gain"
score = b.get_score(
importance_type=self.importance_type if self.importance_type else dft()
)
@ -1251,7 +1281,7 @@ class XGBModel(XGBModelBase):
else:
feature_names = b.feature_names
# gblinear returns all features so the `get` in next line is only for gbtree.
all_features = [score.get(f, 0.) for f in feature_names]
all_features = [score.get(f, 0.0) for f in feature_names]
all_features_arr = np.array(all_features, dtype=np.float32)
total = all_features_arr.sum()
if total == 0:
@ -1273,14 +1303,14 @@ class XGBModel(XGBModelBase):
-------
coef_ : array of shape ``[n_features]`` or ``[n_classes, n_features]``
"""
if self.get_params()['booster'] != 'gblinear':
if self.get_params()["booster"] != "gblinear":
raise AttributeError(
f"Coefficients are not defined for Booster type {self.booster}"
)
b = self.get_booster()
coef = np.array(json.loads(b.get_dump(dump_format='json')[0])['weight'])
coef = np.array(json.loads(b.get_dump(dump_format="json")[0])["weight"])
# Logic for multiclass classification
n_classes = getattr(self, 'n_classes_', None)
n_classes = getattr(self, "n_classes_", None)
if n_classes is not None:
if n_classes > 2:
assert len(coef.shape) == 1
@ -1303,12 +1333,12 @@ class XGBModel(XGBModelBase):
-------
intercept_ : array of shape ``(1,)`` or ``[n_classes]``
"""
if self.get_params()['booster'] != 'gblinear':
if self.get_params()["booster"] != "gblinear":
raise AttributeError(
f"Intercept (bias) is not defined for Booster type {self.booster}"
)
b = self.get_booster()
return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])
return np.array(json.loads(b.get_dump(dump_format="json")[0])["bias"])
PredtT = TypeVar("PredtT", bound=np.ndarray)
@ -1334,10 +1364,12 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
@xgboost_model_doc(
"Implementation of the scikit-learn API for XGBoost classification.",
['model', 'objective'], extra_parameters='''
["model", "objective"],
extra_parameters="""
n_estimators : int
Number of boosting rounds.
''')
""",
)
class XGBClassifier(XGBModel, XGBClassifierBase):
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
@_deprecate_positional_args
@ -1346,7 +1378,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
*,
objective: _SklObjective = "binary:logistic",
use_label_encoder: Optional[bool] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
# must match the parameters for `get_params`
self.use_label_encoder = use_label_encoder
@ -1372,9 +1404,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBClassifier":
# pylint: disable = attribute-defined-outside-init,too-many-statements
with config_context(verbosity=self.verbosity):
evals_result: TrainingCallback.EvalsLog = {}
if _is_cudf_df(y) or _is_cudf_ser(y):
@ -1419,7 +1452,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
params["objective"] = "multi:softprob"
params["num_class"] = self.n_classes_
model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
(
model,
metric,
params,
early_stopping_rounds,
callbacks,
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
train_dmatrix, evals = _wrap_evaluation_matrices(
@ -1463,8 +1502,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
assert XGBModel.fit.__doc__ is not None
fit.__doc__ = XGBModel.fit.__doc__.replace(
'Fit gradient boosting model',
'Fit gradient boosting classifier', 1)
"Fit gradient boosting model", "Fit gradient boosting classifier", 1
)
def predict(
self,
@ -1475,6 +1514,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
with config_context(verbosity=self.verbosity):
class_probs = super().predict(
X=X,
output_margin=output_margin,
@ -1501,7 +1541,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
column_indexes = np.repeat(0, class_probs.shape[0])
column_indexes[class_probs > 0.5] = 1
if hasattr(self, '_le'):
if hasattr(self, "_le"):
return self._le.inverse_transform(column_indexes)
return column_indexes
@ -1513,7 +1553,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
""" Predict the probability of each `X` example being of a given class.
"""Predict the probability of each `X` example being of a given class.
.. note:: This function is only thread safe for `gbtree` and `dart`.
@ -1552,7 +1592,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
validate_features=validate_features,
base_margin=base_margin,
iteration_range=iteration_range,
output_margin=True
output_margin=True,
)
class_prob = softmax(raw_predt, axis=1)
return class_prob
@ -1561,106 +1601,40 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
ntree_limit=ntree_limit,
validate_features=validate_features,
base_margin=base_margin,
iteration_range=iteration_range
iteration_range=iteration_range,
)
# If model is loaded from a raw booster there's no `n_classes_`
return _cls_predict_proba(getattr(self, "n_classes_", 0), class_probs, np.vstack)
return _cls_predict_proba(
getattr(self, "n_classes_", 0), class_probs, np.vstack
)
@xgboost_model_doc(
"scikit-learn API for XGBoost random forest classification.",
['model', 'objective'],
extra_parameters='''
["model", "objective"],
extra_parameters="""
n_estimators : int
Number of trees in random forest to fit.
''')
""",
)
class XGBRFClassifier(XGBClassifier):
# pylint: disable=missing-docstring
@_deprecate_positional_args
def __init__(
self, *,
self,
*,
learning_rate: float = 1.0,
subsample: float = 0.8,
colsample_bynode: float = 0.8,
reg_lambda: float = 1e-5,
**kwargs: Any
**kwargs: Any,
):
super().__init__(learning_rate=learning_rate,
subsample=subsample,
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda,
**kwargs)
_check_rf_callback(self.early_stopping_rounds, self.callbacks)
def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params()
params['num_parallel_tree'] = self.n_estimators
return params
def get_num_boosting_rounds(self) -> int:
return 1
# pylint: disable=unused-argument
@_deprecate_positional_args
def fit(
self,
X: ArrayLike,
y: ArrayLike,
*,
sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None
) -> "XGBRFClassifier":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks)
super().fit(**args)
return self
@xgboost_model_doc(
"Implementation of the scikit-learn API for XGBoost regression.",
['estimators', 'model', 'objective'])
class XGBRegressor(XGBModel, XGBRegressorBase):
# pylint: disable=missing-docstring
@_deprecate_positional_args
def __init__(
self, *, objective: _SklObjective = "reg:squarederror", **kwargs: Any
) -> None:
super().__init__(objective=objective, **kwargs)
@xgboost_model_doc(
"scikit-learn API for XGBoost random forest regression.",
['model', 'objective'], extra_parameters='''
n_estimators : int
Number of trees in random forest to fit.
''')
class XGBRFRegressor(XGBRegressor):
# pylint: disable=missing-docstring
@_deprecate_positional_args
def __init__(
self,
*,
learning_rate: float = 1.0,
subsample: float = 0.8,
colsample_bynode: float = 0.8,
reg_lambda: float = 1e-5,
**kwargs: Any
) -> None:
super().__init__(
learning_rate=learning_rate,
subsample=subsample,
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda,
**kwargs
**kwargs,
)
_check_rf_callback(self.early_stopping_rounds, self.callbacks)
@ -1689,7 +1663,82 @@ class XGBRFRegressor(XGBRegressor):
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBRFClassifier":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks)
super().fit(**args)
return self
@xgboost_model_doc(
"Implementation of the scikit-learn API for XGBoost regression.",
["estimators", "model", "objective"],
)
class XGBRegressor(XGBModel, XGBRegressorBase):
# pylint: disable=missing-docstring
@_deprecate_positional_args
def __init__(
self, *, objective: _SklObjective = "reg:squarederror", **kwargs: Any
) -> None:
super().__init__(objective=objective, **kwargs)
@xgboost_model_doc(
"scikit-learn API for XGBoost random forest regression.",
["model", "objective"],
extra_parameters="""
n_estimators : int
Number of trees in random forest to fit.
""",
)
class XGBRFRegressor(XGBRegressor):
# pylint: disable=missing-docstring
@_deprecate_positional_args
def __init__(
self,
*,
learning_rate: float = 1.0,
subsample: float = 0.8,
colsample_bynode: float = 0.8,
reg_lambda: float = 1e-5,
**kwargs: Any,
) -> None:
super().__init__(
learning_rate=learning_rate,
subsample=subsample,
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda,
**kwargs,
)
_check_rf_callback(self.early_stopping_rounds, self.callbacks)
def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params()
params["num_parallel_tree"] = self.n_estimators
return params
def get_num_boosting_rounds(self) -> int:
return 1
# pylint: disable=unused-argument
@_deprecate_positional_args
def fit(
self,
X: ArrayLike,
y: ArrayLike,
*,
sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBRFRegressor":
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
_check_rf_callback(early_stopping_rounds, callbacks)
@ -1698,9 +1747,9 @@ class XGBRFRegressor(XGBRegressor):
@xgboost_model_doc(
'Implementation of the Scikit-Learn API for XGBoost Ranking.',
['estimators', 'model'],
end_note='''
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
["estimators", "model"],
end_note="""
.. note::
A custom objective function is currently not supported by XGBRanker.
@ -1737,7 +1786,8 @@ class XGBRFRegressor(XGBRegressor):
then your group array should be ``[3, 4]``. Sometimes using query id (`qid`)
instead of group can be more convenient.
''')
""",
)
class XGBRanker(XGBModel, XGBRankerMixIn):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
@_deprecate_positional_args
@ -1768,7 +1818,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBRanker":
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""Fit gradient boosting ranker
@ -1853,13 +1903,15 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
"""
# check if group information is provided
with config_context(verbosity=self.verbosity):
if group is None and qid is None:
raise ValueError("group or qid is required for ranking task")
if eval_set is not None:
if eval_group is None and eval_qid is None:
raise ValueError(
"eval_group or eval_qid is required if eval_set is not None")
"eval_group or eval_qid is required if eval_set is not None"
)
train_dmatrix, evals = _wrap_evaluation_matrices(
missing=self.missing,
X=X,
@ -1882,12 +1934,18 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
evals_result: TrainingCallback.EvalsLog = {}
params = self.get_xgb_params()
model, metric, params, early_stopping_rounds, callbacks = self._configure_fit(
(
model,
metric,
params,
early_stopping_rounds,
callbacks,
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
if callable(metric):
raise ValueError(
'Custom evaluation metric is not yet supported for XGBRanker.'
"Custom evaluation metric is not yet supported for XGBRanker."
)
self._Booster = train(
@ -1898,8 +1956,9 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
evals=evals,
evals_result=evals_result,
custom_metric=metric,
verbose_eval=verbose, xgb_model=model,
callbacks=callbacks
verbose_eval=verbose,
xgb_model=model,
callbacks=callbacks,
)
self.objective = params["objective"]
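To make the ``group``/``qid`` note in the ranker docstring concrete, a hedged sketch with two queries of 3 and 4 documents (the data is illustrative):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(7, 5)
    y = np.array([0, 1, 2, 0, 1, 2, 1])    # relevance labels
    qid = np.array([0, 0, 0, 1, 1, 1, 1])  # equivalent to group = [3, 4]

    ranker = xgb.XGBRanker(n_estimators=10, objective="rank:ndcg")
    ranker.fit(X, y, qid=qid)
    scores = ranker.predict(X)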

View File

@ -113,7 +113,9 @@ if __name__ == "__main__":
run_formatter(path)
for path in [
"python-package/xgboost/dask.py",
"python-package/xgboost/sklearn.py",
"python-package/xgboost/spark",
"tests/python/test_config.py",
"tests/python/test_spark/test_data.py",
"tests/python-gpu/test_gpu_spark/test_data.py",
"tests/ci_build/lint_python.py",

View File

@ -1,13 +1,15 @@
# -*- coding: utf-8 -*-
import xgboost as xgb
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import pytest
import testing as tm
import xgboost as xgb
@pytest.mark.parametrize('verbosity_level', [0, 1, 2, 3])
@pytest.mark.parametrize("verbosity_level", [0, 1, 2, 3])
def test_global_config_verbosity(verbosity_level):
def get_current_verbosity():
return xgb.get_config()['verbosity']
return xgb.get_config()["verbosity"]
old_verbosity = get_current_verbosity()
with xgb.config_context(verbosity=verbosity_level):
@ -16,13 +18,48 @@ def test_global_config_verbosity(verbosity_level):
assert old_verbosity == get_current_verbosity()
@pytest.mark.parametrize('use_rmm', [False, True])
@pytest.mark.parametrize("use_rmm", [False, True])
def test_global_config_use_rmm(use_rmm):
def get_current_use_rmm_flag():
return xgb.get_config()['use_rmm']
return xgb.get_config()["use_rmm"]
old_use_rmm_flag = get_current_use_rmm_flag()
with xgb.config_context(use_rmm=use_rmm):
new_use_rmm_flag = get_current_use_rmm_flag()
assert new_use_rmm_flag == use_rmm
assert old_use_rmm_flag == get_current_use_rmm_flag()
def test_nested_config():
with xgb.config_context(verbosity=3):
assert xgb.get_config()["verbosity"] == 3
with xgb.config_context(verbosity=2):
assert xgb.get_config()["verbosity"] == 2
with xgb.config_context(verbosity=1):
assert xgb.get_config()["verbosity"] == 1
assert xgb.get_config()["verbosity"] == 2
assert xgb.get_config()["verbosity"] == 3
with xgb.config_context(verbosity=3):
assert xgb.get_config()["verbosity"] == 3
with xgb.config_context(verbosity=None):
assert xgb.get_config()["verbosity"] == 3 # None has no effect
verbosity = xgb.get_config()["verbosity"]
xgb.set_config(verbosity=2)
assert xgb.get_config()["verbosity"] == 2
with xgb.config_context(verbosity=3):
assert xgb.get_config()["verbosity"] == 3
xgb.set_config(verbosity=verbosity) # reset
def test_thread_safty():
n_threads = multiprocessing.cpu_count()
futures = []
with ThreadPoolExecutor(max_workers=n_threads) as executor:
for i in range(256):
f = executor.submit(test_nested_config)
futures.append(f)
for f in futures:
f.result()