diff --git a/demo/dask/README.rst b/demo/dask/README.rst index 456425e91..6d44031b0 100644 --- a/demo/dask/README.rst +++ b/demo/dask/README.rst @@ -1,3 +1,5 @@ +.. _dask-examples: + XGBoost Dask Feature Walkthrough ================================ diff --git a/doc/conf.py b/doc/conf.py index 7ebd0c2af..c362709d5 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -126,7 +126,7 @@ master_doc = 'index' # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" autoclass_content = 'both' diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index 44bb643a2..2eebeda11 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -115,7 +115,7 @@ Alternatively, XGBoost also implements the Scikit-Learn interface with :py:class:`~xgboost.dask.DaskXGBRanker` and 2 random forest variances. This wrapper is similar to the single node Scikit-Learn interface in xgboost, with dask collection as inputs and has an additional ``client`` attribute. See following sections and -:ref:`sphx_glr_python_dask-examples` for more examples. +:ref:`dask-examples` for more examples. ****************** diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst index 2d0ec8a2f..4bdc448da 100644 --- a/doc/tutorials/index.rst +++ b/doc/tutorials/index.rst @@ -16,6 +16,7 @@ See `Awesome XGBoost `_ for mo Distributed XGBoost with XGBoost4J-Spark Distributed XGBoost with XGBoost4J-Spark-GPU dask + spark_estimator ray dart monotonic diff --git a/python-package/xgboost/config.py b/python-package/xgboost/config.py index 2344ae4a3..34948feed 100644 --- a/python-package/xgboost/config.py +++ b/python-package/xgboost/config.py @@ -70,6 +70,23 @@ def config_doc( # Suppress warning caused by model generated with XGBoost version < 1.0.0 bst = xgb.Booster(model_file='./old_model.bin') assert xgb.get_config()['verbosity'] == 2 # old value restored + + Nested configuration context is also supported: + + Example + ------- + + .. 
code-block:: python + + with xgb.config_context(verbosity=3): + assert xgb.get_config()["verbosity"] == 3 + with xgb.config_context(verbosity=2): + assert xgb.get_config()["verbosity"] == 2 + + xgb.set_config(verbosity=2) + assert xgb.get_config()["verbosity"] == 2 + with xgb.config_context(verbosity=3): + assert xgb.get_config()["verbosity"] == 3 """ def none_to_str(value: Optional[str]) -> str: @@ -98,7 +115,11 @@ def config_doc( Keyword arguments representing the parameters and their values """) def set_config(**new_config: Any) -> None: - config = json.dumps(new_config) + not_none = {} + for k, v in new_config.items(): + if v is not None: + not_none[k] = v + config = json.dumps(not_none) _check_call(_LIB.XGBSetGlobalConfig(c_str(config))) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index a48a0f45c..4d7e5a624 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1,43 +1,49 @@ # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, too-many-lines """Scikit-Learn Wrapper interface for XGBoost.""" import copy -import warnings import json import os +import warnings from typing import ( - Union, - Optional, - List, - Dict, + Any, Callable, + Dict, + List, + Optional, Sequence, Tuple, - Any, - TypeVar, Type, + TypeVar, + Union, cast, ) import numpy as np from scipy.special import softmax -from .core import Booster, DMatrix, XGBoostError -from .core import _deprecate_positional_args, _convert_ntree_limit -from .core import Metric -from .training import train -from .callback import TrainingCallback -from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array from ._typing import ArrayLike, FeatureNames, FeatureTypes +from .callback import TrainingCallback # Do not use class names on scikit-learn directly. Re-define the classes on # .compat to guarantee the behavior without scikit-learn from .compat import ( SKLEARN_INSTALLED, - XGBModelBase, XGBClassifierBase, - XGBRegressorBase, + XGBModelBase, XGBoostLabelEncoder, + XGBRegressorBase, ) +from .config import config_context +from .core import ( + Booster, + DMatrix, + Metric, + XGBoostError, + _convert_ntree_limit, + _deprecate_positional_args, +) +from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array +from .training import train class XGBRankerMixIn: # pylint: disable=too-few-public-methods @@ -59,9 +65,7 @@ def _check_rf_callback( _SklObjective = Optional[ - Union[ - str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] - ] + Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]] ] @@ -95,10 +99,12 @@ def _objective_decorator( The training set from which the labels will be extracted using ``dmatrix.get_label()`` """ + def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]: """internal function""" labels = dmatrix.get_label() return func(labels, preds) + return inner @@ -109,19 +115,21 @@ def _metric_decorator(func: Callable) -> Metric: is compatible with :py:func:`train` """ + def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]: y_true = dmatrix.get_label() return func.__name__, func(y_true, y_score) + return inner -__estimator_doc = ''' +__estimator_doc = """ n_estimators : int Number of gradient boosted trees. Equivalent to number of boosting rounds. -''' +""" -__model_doc = f''' +__model_doc = f""" max_depth : Optional[int] Maximum tree depth for base learners. 
max_leaves : @@ -146,10 +154,10 @@ __model_doc = f''' recommended to study this option from the parameters document :doc:`tree method ` n_jobs : Optional[int] - Number of parallel threads used to run xgboost. When used with other Scikit-Learn - algorithms like grid search, you may choose which algorithm to parallelize and - balance the threads. Creating thread contention will significantly slow down both - algorithms. + Number of parallel threads used to run xgboost. When used with other + Scikit-Learn algorithms like grid search, you may choose which algorithm to + parallelize and balance the threads. Creating thread contention will + significantly slow down both algorithms. gamma : Optional[float] (min_split_loss) Minimum loss reduction required to make a further partition on a leaf node of the tree. @@ -332,9 +340,9 @@ __model_doc = f''' \\*\\*kwargs is unsupported by scikit-learn. We do not guarantee that parameters passed via this argument will interact properly with scikit-learn. -''' +""" -__custom_obj_note = ''' +__custom_obj_note = """ .. note:: Custom objective function A custom objective function can be provided for the ``objective`` @@ -350,15 +358,16 @@ __custom_obj_note = ''' The value of the gradient for each sample point. hess: array_like of shape [n_samples] The value of the second derivative for each sample point -''' +""" def xgboost_model_doc( - header: str, items: List[str], + header: str, + items: List[str], extra_parameters: Optional[str] = None, - end_note: Optional[str] = None + end_note: Optional[str] = None, ) -> Callable[[Type], Type]: - '''Obtain documentation for Scikit-Learn wrappers + """Obtain documentation for Scikit-Learn wrappers Parameters ---------- @@ -372,29 +381,34 @@ def xgboost_model_doc( extra_parameters: str Document for class specific parameters, placed at the head. end_note: str - Extra notes put to the end. -''' + Extra notes put to the end.""" + def get_doc(item: str) -> str: - '''Return selected item''' - __doc = {'estimators': __estimator_doc, - 'model': __model_doc, - 'objective': __custom_obj_note} + """Return selected item""" + __doc = { + "estimators": __estimator_doc, + "model": __model_doc, + "objective": __custom_obj_note, + } return __doc[item] def adddoc(cls: Type) -> Type: - doc = [''' + doc = [ + """ Parameters ---------- -'''] +""" + ] if extra_parameters: doc.append(extra_parameters) doc.extend([get_doc(i) for i in items]) if end_note: doc.append(end_note) - full_doc = [header + '\n\n'] + full_doc = [header + "\n\n"] full_doc.extend(doc) - cls.__doc__ = ''.join(full_doc) + cls.__doc__ = "".join(full_doc) return cls + return adddoc @@ -416,9 +430,7 @@ def _wrap_evaluation_matrices( enable_categorical: bool, feature_types: Optional[FeatureTypes], ) -> Tuple[Any, List[Tuple[Any, str]]]: - """Convert array_like evaluation matrices into DMatrix. Perform validation on the way. - - """ + """Convert array_like evaluation matrices into DMatrix. Perform validation on the way.""" train_dmatrix = create_dmatrix( data=X, label=y, @@ -439,8 +451,8 @@ def _wrap_evaluation_matrices( return [None] * n_validation if len(meta) != n_validation: raise ValueError( - f"{name}'s length does not equal `eval_set`'s length, " + - f"expecting {n_validation}, got {len(meta)}" + f"{name}'s length does not equal `eval_set`'s length, " + + f"expecting {n_validation}, got {len(meta)}" ) return meta @@ -459,11 +471,12 @@ def _wrap_evaluation_matrices( # Skip the duplicated entry. 
if all( ( - valid_X is X, valid_y is y, + valid_X is X, + valid_y is y, sample_weight_eval_set[i] is sample_weight, base_margin_eval_set[i] is base_margin, eval_group[i] is group, - eval_qid[i] is qid + eval_qid[i] is qid, ) ): evals.append(train_dmatrix) @@ -502,8 +515,10 @@ def _wrap_evaluation_matrices( return train_dmatrix, evals -@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""", - ['estimators', 'model', 'objective']) +@xgboost_model_doc( + """Implementation of the Scikit-Learn API for XGBoost.""", + ["estimators", "model", "objective"], +) class XGBModel(XGBModelBase): # pylint: disable=too-many-arguments, too-many-instance-attributes, missing-docstring def __init__( @@ -546,7 +561,7 @@ class XGBModel(XGBModelBase): eval_metric: Optional[Union[str, List[str], Callable]] = None, early_stopping_rounds: Optional[int] = None, callbacks: Optional[List[TrainingCallback]] = None, - **kwargs: Any + **kwargs: Any, ) -> None: if not SKLEARN_INSTALLED: raise ImportError( @@ -595,8 +610,8 @@ class XGBModel(XGBModelBase): self.kwargs = kwargs def _more_tags(self) -> Dict[str, bool]: - '''Tags used for scikit-learn data validation.''' - return {'allow_nan': True, 'no_validation': True} + """Tags used for scikit-learn data validation.""" + return {"allow_nan": True, "no_validation": True} def __sklearn_is_fitted__(self) -> bool: return hasattr(self, "_Booster") @@ -612,7 +627,8 @@ class XGBModel(XGBModelBase): """ if not self.__sklearn_is_fitted__(): from sklearn.exceptions import NotFittedError - raise NotFittedError('need to call fit or load_model beforehand') + + raise NotFittedError("need to call fit or load_model beforehand") return self._Booster def set_params(self, **params: Any) -> "XGBModel": @@ -640,7 +656,7 @@ class XGBModel(XGBModelBase): self.kwargs = {} self.kwargs[key] = value - if hasattr(self, '_Booster'): + if hasattr(self, "_Booster"): parameters = self.get_xgb_params() self.get_booster().set_param(parameters) @@ -662,9 +678,10 @@ class XGBModel(XGBModelBase): # if kwargs is a dict, update params accordingly if hasattr(self, "kwargs") and isinstance(self.kwargs, dict): params.update(self.kwargs) - if isinstance(params['random_state'], np.random.RandomState): - params['random_state'] = params['random_state'].randint( - np.iinfo(np.int32).max) + if isinstance(params["random_state"], np.random.RandomState): + params["random_state"] = params["random_state"].randint( + np.iinfo(np.int32).max + ) def parse_parameter(value: Any) -> Optional[Union[int, float, str]]: for t in (int, float, str): @@ -683,7 +700,7 @@ class XGBModel(XGBModelBase): while stack: obj = stack.pop() for k, v in obj.items(): - if k.endswith('_param'): + if k.endswith("_param"): for p_k, p_v in v.items(): internal[p_k] = p_v elif isinstance(v, dict): @@ -722,7 +739,7 @@ class XGBModel(XGBModelBase): return self.n_estimators def _get_type(self) -> str: - if not hasattr(self, '_estimator_type'): + if not hasattr(self, "_estimator_type"): raise TypeError( "`_estimator_type` undefined. " "Please use appropriate mixin to define estimator type." 
@@ -732,14 +749,14 @@ class XGBModel(XGBModelBase): def save_model(self, fname: Union[str, os.PathLike]) -> None: meta: Dict[str, Any] = {} for k, v in self.__dict__.items(): - if k == '_le': - meta['_le'] = self._le.to_json() + if k == "_le": + meta["_le"] = self._le.to_json() continue - if k == '_Booster': + if k == "_Booster": continue - if k == 'classes_': + if k == "classes_": # numpy array is not JSON serializable - meta['classes_'] = self.classes_.tolist() + meta["classes_"] = self.classes_.tolist() continue if k == "feature_types": # Use the `feature_types` attribute from booster instead. @@ -749,8 +766,10 @@ class XGBModel(XGBModelBase): json.dumps({k: v}) meta[k] = v except TypeError: - warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning) - meta['_estimator_type'] = self._get_type() + warnings.warn( + str(k) + " is not saved in Scikit-Learn meta.", UserWarning + ) + meta["_estimator_type"] = self._get_type() meta_str = json.dumps(meta) self.get_booster().set_attr(scikit_learn=meta_str) self.get_booster().save_model(fname) @@ -761,27 +780,25 @@ class XGBModel(XGBModelBase): def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: # pylint: disable=attribute-defined-outside-init - if not hasattr(self, '_Booster'): - self._Booster = Booster({'n_jobs': self.n_jobs}) + if not hasattr(self, "_Booster"): + self._Booster = Booster({"n_jobs": self.n_jobs}) self.get_booster().load_model(fname) - meta_str = self.get_booster().attr('scikit_learn') + meta_str = self.get_booster().attr("scikit_learn") if meta_str is None: # FIXME(jiaming): This doesn't have to be a problem as most of the needed # information like num_class and objective is in Learner class. - warnings.warn( - 'Loading a native XGBoost model with Scikit-Learn interface.' - ) + warnings.warn("Loading a native XGBoost model with Scikit-Learn interface.") return meta = json.loads(meta_str) states = {} for k, v in meta.items(): - if k == '_le': + if k == "_le": self._le = XGBoostLabelEncoder() self._le.from_json(v) continue # FIXME(jiaming): This can be removed once label encoder is gone since we can # generate it from `np.arange(self.n_classes_)` - if k == 'classes_': + if k == "classes_": self.classes_ = np.array(v) continue if k == "feature_types": @@ -907,7 +924,7 @@ class XGBModel(XGBModelBase): sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, feature_weights: Optional[ArrayLike] = None, - callbacks: Optional[Sequence[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None, ) -> "XGBModel": # pylint: disable=invalid-name,attribute-defined-outside-init """Fit gradient boosting model. @@ -963,54 +980,61 @@ class XGBModel(XGBModelBase): .. deprecated:: 1.6.0 Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead. 
""" - evals_result: TrainingCallback.EvalsLog = {} - train_dmatrix, evals = _wrap_evaluation_matrices( - missing=self.missing, - X=X, - y=y, - group=None, - qid=None, - sample_weight=sample_weight, - base_margin=base_margin, - feature_weights=feature_weights, - eval_set=eval_set, - sample_weight_eval_set=sample_weight_eval_set, - base_margin_eval_set=base_margin_eval_set, - eval_group=None, - eval_qid=None, - create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), - enable_categorical=self.enable_categorical, - feature_types=self.feature_types - ) - params = self.get_xgb_params() + with config_context(verbosity=self.verbosity): + evals_result: TrainingCallback.EvalsLog = {} + train_dmatrix, evals = _wrap_evaluation_matrices( + missing=self.missing, + X=X, + y=y, + group=None, + qid=None, + sample_weight=sample_weight, + base_margin=base_margin, + feature_weights=feature_weights, + eval_set=eval_set, + sample_weight_eval_set=sample_weight_eval_set, + base_margin_eval_set=base_margin_eval_set, + eval_group=None, + eval_qid=None, + create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), + enable_categorical=self.enable_categorical, + feature_types=self.feature_types, + ) + params = self.get_xgb_params() - if callable(self.objective): - obj: Optional[ - Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] - ] = _objective_decorator(self.objective) - params["objective"] = "reg:squarederror" - else: - obj = None + if callable(self.objective): + obj: Optional[ + Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] + ] = _objective_decorator(self.objective) + params["objective"] = "reg:squarederror" + else: + obj = None - model, metric, params, early_stopping_rounds, callbacks = self._configure_fit( - xgb_model, eval_metric, params, early_stopping_rounds, callbacks - ) - self._Booster = train( - params, - train_dmatrix, - self.get_num_boosting_rounds(), - evals=evals, - early_stopping_rounds=early_stopping_rounds, - evals_result=evals_result, - obj=obj, - custom_metric=metric, - verbose_eval=verbose, - xgb_model=model, - callbacks=callbacks, - ) + ( + model, + metric, + params, + early_stopping_rounds, + callbacks, + ) = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds, callbacks + ) + self._Booster = train( + params, + train_dmatrix, + self.get_num_boosting_rounds(), + evals=evals, + early_stopping_rounds=early_stopping_rounds, + evals_result=evals_result, + obj=obj, + custom_metric=metric, + verbose_eval=verbose, + xgb_model=model, + callbacks=callbacks, + ) - self._set_evaluation_result(evals_result) - return self + self._set_evaluation_result(evals_result) + return self def _can_use_inplace_predict(self) -> bool: # When predictor is explicitly set, using `inplace_predict` might result into @@ -1025,7 +1049,7 @@ class XGBModel(XGBModelBase): def _get_iteration_range( self, iteration_range: Optional[Tuple[int, int]] ) -> Tuple[int, int]: - if (iteration_range is None or iteration_range[1] == 0): + if iteration_range is None or iteration_range[1] == 0: # Use best_iteration if defined. try: iteration_range = (0, self.best_iteration + 1) @@ -1067,8 +1091,8 @@ class XGBModel(XGBModelBase): iteration_range : Specifies which layer of trees are used in prediction. For example, if a random forest is trained with 100 rounds. Specifying ``iteration_range=(10, - 20)``, then only the forests built during [10, 20) (half open set) rounds are - used in this prediction. 
+ 20)``, then only the forests built during [10, 20) (half open set) rounds + are used in this prediction. .. versionadded:: 1.4.0 @@ -1077,47 +1101,50 @@ class XGBModel(XGBModelBase): prediction """ - iteration_range = _convert_ntree_limit( - self.get_booster(), ntree_limit, iteration_range - ) - iteration_range = self._get_iteration_range(iteration_range) - if self._can_use_inplace_predict(): - try: - predts = self.get_booster().inplace_predict( - data=X, - iteration_range=iteration_range, - predict_type="margin" if output_margin else "value", - missing=self.missing, - base_margin=base_margin, - validate_features=validate_features, - ) - if _is_cupy_array(predts): - import cupy # pylint: disable=import-error - predts = cupy.asnumpy(predts) # ensure numpy array is used. - return predts - except TypeError: - # coo, csc, dt - pass + with config_context(verbosity=self.verbosity): + iteration_range = _convert_ntree_limit( + self.get_booster(), ntree_limit, iteration_range + ) + iteration_range = self._get_iteration_range(iteration_range) + if self._can_use_inplace_predict(): + try: + predts = self.get_booster().inplace_predict( + data=X, + iteration_range=iteration_range, + predict_type="margin" if output_margin else "value", + missing=self.missing, + base_margin=base_margin, + validate_features=validate_features, + ) + if _is_cupy_array(predts): + import cupy # pylint: disable=import-error - test = DMatrix( - X, - base_margin=base_margin, - missing=self.missing, - nthread=self.n_jobs, - feature_types=self.feature_types, - enable_categorical=self.enable_categorical - ) - return self.get_booster().predict( - data=test, - iteration_range=iteration_range, - output_margin=output_margin, - validate_features=validate_features, - ) + predts = cupy.asnumpy(predts) # ensure numpy array is used. + return predts + except TypeError: + # coo, csc, dt + pass + + test = DMatrix( + X, + base_margin=base_margin, + missing=self.missing, + nthread=self.n_jobs, + feature_types=self.feature_types, + enable_categorical=self.enable_categorical, + ) + return self.get_booster().predict( + data=test, + iteration_range=iteration_range, + output_margin=output_margin, + validate_features=validate_features, + ) def apply( - self, X: ArrayLike, + self, + X: ArrayLike, ntree_limit: int = 0, - iteration_range: Optional[Tuple[int, int]] = None + iteration_range: Optional[Tuple[int, int]] = None, ) -> np.ndarray: """Return the predicted leaf every tree for each sample. If the model is trained with early stopping, then `best_iteration` is used automatically. @@ -1141,18 +1168,20 @@ class XGBModel(XGBModelBase): ``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering. 
""" - iteration_range = _convert_ntree_limit( - self.get_booster(), ntree_limit, iteration_range - ) - iteration_range = self._get_iteration_range(iteration_range) - test_dmatrix = DMatrix( - X, missing=self.missing, feature_types=self.feature_types, nthread=self.n_jobs - ) - return self.get_booster().predict( - test_dmatrix, - pred_leaf=True, - iteration_range=iteration_range - ) + with config_context(verbosity=self.verbosity): + iteration_range = _convert_ntree_limit( + self.get_booster(), ntree_limit, iteration_range + ) + iteration_range = self._get_iteration_range(iteration_range) + test_dmatrix = DMatrix( + X, + missing=self.missing, + feature_types=self.feature_types, + nthread=self.n_jobs, + ) + return self.get_booster().predict( + test_dmatrix, pred_leaf=True, iteration_range=iteration_range + ) def evals_result(self) -> Dict[str, Dict[str, List[float]]]: """Return the evaluation results. @@ -1208,13 +1237,13 @@ class XGBModel(XGBModelBase): return getattr(booster, attr) except AttributeError as e: raise AttributeError( - f'`{attr}` in only defined when early stopping is used.' + f"`{attr}` in only defined when early stopping is used." ) from e @property def best_score(self) -> float: """The best score obtained by early stopping.""" - return float(self._early_stopping_attr('best_score')) + return float(self._early_stopping_attr("best_score")) @property def best_iteration(self) -> int: @@ -1222,11 +1251,11 @@ class XGBModel(XGBModelBase): for instance if the best iteration is the first round, then best_iteration is 0. """ - return int(self._early_stopping_attr('best_iteration')) + return int(self._early_stopping_attr("best_iteration")) @property def best_ntree_limit(self) -> int: - return int(self._early_stopping_attr('best_ntree_limit')) + return int(self._early_stopping_attr("best_ntree_limit")) @property def feature_importances_(self) -> np.ndarray: @@ -1243,6 +1272,7 @@ class XGBModel(XGBModelBase): def dft() -> str: return "weight" if self.booster == "gblinear" else "gain" + score = b.get_score( importance_type=self.importance_type if self.importance_type else dft() ) @@ -1251,7 +1281,7 @@ class XGBModel(XGBModelBase): else: feature_names = b.feature_names # gblinear returns all features so the `get` in next line is only for gbtree. - all_features = [score.get(f, 0.) 
for f in feature_names] + all_features = [score.get(f, 0.0) for f in feature_names] all_features_arr = np.array(all_features, dtype=np.float32) total = all_features_arr.sum() if total == 0: @@ -1273,14 +1303,14 @@ class XGBModel(XGBModelBase): ------- coef_ : array of shape ``[n_features]`` or ``[n_classes, n_features]`` """ - if self.get_params()['booster'] != 'gblinear': + if self.get_params()["booster"] != "gblinear": raise AttributeError( f"Coefficients are not defined for Booster type {self.booster}" ) b = self.get_booster() - coef = np.array(json.loads(b.get_dump(dump_format='json')[0])['weight']) + coef = np.array(json.loads(b.get_dump(dump_format="json")[0])["weight"]) # Logic for multiclass classification - n_classes = getattr(self, 'n_classes_', None) + n_classes = getattr(self, "n_classes_", None) if n_classes is not None: if n_classes > 2: assert len(coef.shape) == 1 @@ -1303,12 +1333,12 @@ class XGBModel(XGBModelBase): ------- intercept_ : array of shape ``(1,)`` or ``[n_classes]`` """ - if self.get_params()['booster'] != 'gblinear': + if self.get_params()["booster"] != "gblinear": raise AttributeError( f"Intercept (bias) is not defined for Booster type {self.booster}" ) b = self.get_booster() - return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias']) + return np.array(json.loads(b.get_dump(dump_format="json")[0])["bias"]) PredtT = TypeVar("PredtT", bound=np.ndarray) @@ -1334,10 +1364,12 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> @xgboost_model_doc( "Implementation of the scikit-learn API for XGBoost classification.", - ['model', 'objective'], extra_parameters=''' + ["model", "objective"], + extra_parameters=""" n_estimators : int Number of boosting rounds. -''') +""", +) class XGBClassifier(XGBModel, XGBClassifierBase): # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes @_deprecate_positional_args @@ -1346,7 +1378,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): *, objective: _SklObjective = "binary:logistic", use_label_encoder: Optional[bool] = None, - **kwargs: Any + **kwargs: Any, ) -> None: # must match the parameters for `get_params` self.use_label_encoder = use_label_encoder @@ -1372,99 +1404,106 @@ class XGBClassifier(XGBModel, XGBClassifierBase): sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, feature_weights: Optional[ArrayLike] = None, - callbacks: Optional[Sequence[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None, ) -> "XGBClassifier": # pylint: disable = attribute-defined-outside-init,too-many-statements - evals_result: TrainingCallback.EvalsLog = {} + with config_context(verbosity=self.verbosity): + evals_result: TrainingCallback.EvalsLog = {} - if _is_cudf_df(y) or _is_cudf_ser(y): - import cupy as cp # pylint: disable=E0401 + if _is_cudf_df(y) or _is_cudf_ser(y): + import cupy as cp # pylint: disable=E0401 - self.classes_ = cp.unique(y.values) - self.n_classes_ = len(self.classes_) - expected_classes = cp.arange(self.n_classes_) - elif _is_cupy_array(y): - import cupy as cp # pylint: disable=E0401 + self.classes_ = cp.unique(y.values) + self.n_classes_ = len(self.classes_) + expected_classes = cp.arange(self.n_classes_) + elif _is_cupy_array(y): + import cupy as cp # pylint: disable=E0401 - self.classes_ = cp.unique(y) - self.n_classes_ = len(self.classes_) - expected_classes = cp.arange(self.n_classes_) - else: - self.classes_ = np.unique(np.asarray(y)) - 
self.n_classes_ = len(self.classes_) - expected_classes = np.arange(self.n_classes_) - if ( - self.classes_.shape != expected_classes.shape - or not (self.classes_ == expected_classes).all() - ): - raise ValueError( - f"Invalid classes inferred from unique values of `y`. " - f"Expected: {expected_classes}, got {self.classes_}" + self.classes_ = cp.unique(y) + self.n_classes_ = len(self.classes_) + expected_classes = cp.arange(self.n_classes_) + else: + self.classes_ = np.unique(np.asarray(y)) + self.n_classes_ = len(self.classes_) + expected_classes = np.arange(self.n_classes_) + if ( + self.classes_.shape != expected_classes.shape + or not (self.classes_ == expected_classes).all() + ): + raise ValueError( + f"Invalid classes inferred from unique values of `y`. " + f"Expected: {expected_classes}, got {self.classes_}" + ) + + params = self.get_xgb_params() + + if callable(self.objective): + obj: Optional[ + Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] + ] = _objective_decorator(self.objective) + # Use default value. Is it really not used ? + params["objective"] = "binary:logistic" + else: + obj = None + + if self.n_classes_ > 2: + # Switch to using a multiclass objective in the underlying XGB instance + if params.get("objective", None) != "multi:softmax": + params["objective"] = "multi:softprob" + params["num_class"] = self.n_classes_ + + ( + model, + metric, + params, + early_stopping_rounds, + callbacks, + ) = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds, callbacks + ) + train_dmatrix, evals = _wrap_evaluation_matrices( + missing=self.missing, + X=X, + y=y, + group=None, + qid=None, + sample_weight=sample_weight, + base_margin=base_margin, + feature_weights=feature_weights, + eval_set=eval_set, + sample_weight_eval_set=sample_weight_eval_set, + base_margin_eval_set=base_margin_eval_set, + eval_group=None, + eval_qid=None, + create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), + enable_categorical=self.enable_categorical, + feature_types=self.feature_types, ) - params = self.get_xgb_params() + self._Booster = train( + params, + train_dmatrix, + self.get_num_boosting_rounds(), + evals=evals, + early_stopping_rounds=early_stopping_rounds, + evals_result=evals_result, + obj=obj, + custom_metric=metric, + verbose_eval=verbose, + xgb_model=model, + callbacks=callbacks, + ) - if callable(self.objective): - obj: Optional[ - Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] - ] = _objective_decorator(self.objective) - # Use default value. Is it really not used ? 
- params["objective"] = "binary:logistic" - else: - obj = None + if not callable(self.objective): + self.objective = params["objective"] - if self.n_classes_ > 2: - # Switch to using a multiclass objective in the underlying XGB instance - if params.get("objective", None) != "multi:softmax": - params["objective"] = "multi:softprob" - params["num_class"] = self.n_classes_ - - model, metric, params, early_stopping_rounds, callbacks = self._configure_fit( - xgb_model, eval_metric, params, early_stopping_rounds, callbacks - ) - train_dmatrix, evals = _wrap_evaluation_matrices( - missing=self.missing, - X=X, - y=y, - group=None, - qid=None, - sample_weight=sample_weight, - base_margin=base_margin, - feature_weights=feature_weights, - eval_set=eval_set, - sample_weight_eval_set=sample_weight_eval_set, - base_margin_eval_set=base_margin_eval_set, - eval_group=None, - eval_qid=None, - create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), - enable_categorical=self.enable_categorical, - feature_types=self.feature_types, - ) - - self._Booster = train( - params, - train_dmatrix, - self.get_num_boosting_rounds(), - evals=evals, - early_stopping_rounds=early_stopping_rounds, - evals_result=evals_result, - obj=obj, - custom_metric=metric, - verbose_eval=verbose, - xgb_model=model, - callbacks=callbacks, - ) - - if not callable(self.objective): - self.objective = params["objective"] - - self._set_evaluation_result(evals_result) - return self + self._set_evaluation_result(evals_result) + return self assert XGBModel.fit.__doc__ is not None fit.__doc__ = XGBModel.fit.__doc__.replace( - 'Fit gradient boosting model', - 'Fit gradient boosting classifier', 1) + "Fit gradient boosting model", "Fit gradient boosting classifier", 1 + ) def predict( self, @@ -1475,35 +1514,36 @@ class XGBClassifier(XGBModel, XGBClassifierBase): base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> np.ndarray: - class_probs = super().predict( - X=X, - output_margin=output_margin, - ntree_limit=ntree_limit, - validate_features=validate_features, - base_margin=base_margin, - iteration_range=iteration_range, - ) - if output_margin: - # If output_margin is active, simply return the scores - return class_probs + with config_context(verbosity=self.verbosity): + class_probs = super().predict( + X=X, + output_margin=output_margin, + ntree_limit=ntree_limit, + validate_features=validate_features, + base_margin=base_margin, + iteration_range=iteration_range, + ) + if output_margin: + # If output_margin is active, simply return the scores + return class_probs - if len(class_probs.shape) > 1 and self.n_classes_ != 2: - # multi-class, turns softprob into softmax - column_indexes: np.ndarray = np.argmax(class_probs, axis=1) # type: ignore - elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1: - # multi-label - column_indexes = np.zeros(class_probs.shape) - column_indexes[class_probs > 0.5] = 1 - elif self.objective == "multi:softmax": - return class_probs.astype(np.int32) - else: - # turns soft logit into class label - column_indexes = np.repeat(0, class_probs.shape[0]) - column_indexes[class_probs > 0.5] = 1 + if len(class_probs.shape) > 1 and self.n_classes_ != 2: + # multi-class, turns softprob into softmax + column_indexes: np.ndarray = np.argmax(class_probs, axis=1) # type: ignore + elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1: + # multi-label + column_indexes = np.zeros(class_probs.shape) + column_indexes[class_probs > 0.5] = 1 + elif self.objective == 
"multi:softmax": + return class_probs.astype(np.int32) + else: + # turns soft logit into class label + column_indexes = np.repeat(0, class_probs.shape[0]) + column_indexes[class_probs > 0.5] = 1 - if hasattr(self, '_le'): - return self._le.inverse_transform(column_indexes) - return column_indexes + if hasattr(self, "_le"): + return self._le.inverse_transform(column_indexes) + return column_indexes def predict_proba( self, @@ -1513,7 +1553,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> np.ndarray: - """ Predict the probability of each `X` example being of a given class. + """Predict the probability of each `X` example being of a given class. .. note:: This function is only thread safe for `gbtree` and `dart`. @@ -1552,7 +1592,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): validate_features=validate_features, base_margin=base_margin, iteration_range=iteration_range, - output_margin=True + output_margin=True, ) class_prob = softmax(raw_predt, axis=1) return class_prob @@ -1561,106 +1601,40 @@ class XGBClassifier(XGBModel, XGBClassifierBase): ntree_limit=ntree_limit, validate_features=validate_features, base_margin=base_margin, - iteration_range=iteration_range + iteration_range=iteration_range, ) # If model is loaded from a raw booster there's no `n_classes_` - return _cls_predict_proba(getattr(self, "n_classes_", 0), class_probs, np.vstack) + return _cls_predict_proba( + getattr(self, "n_classes_", 0), class_probs, np.vstack + ) @xgboost_model_doc( "scikit-learn API for XGBoost random forest classification.", - ['model', 'objective'], - extra_parameters=''' + ["model", "objective"], + extra_parameters=""" n_estimators : int Number of trees in random forest to fit. 
-''') +""", +) class XGBRFClassifier(XGBClassifier): # pylint: disable=missing-docstring @_deprecate_positional_args def __init__( - self, *, + self, + *, learning_rate: float = 1.0, subsample: float = 0.8, colsample_bynode: float = 0.8, reg_lambda: float = 1e-5, - **kwargs: Any + **kwargs: Any, ): - super().__init__(learning_rate=learning_rate, - subsample=subsample, - colsample_bynode=colsample_bynode, - reg_lambda=reg_lambda, - **kwargs) - _check_rf_callback(self.early_stopping_rounds, self.callbacks) - - def get_xgb_params(self) -> Dict[str, Any]: - params = super().get_xgb_params() - params['num_parallel_tree'] = self.n_estimators - return params - - def get_num_boosting_rounds(self) -> int: - return 1 - - # pylint: disable=unused-argument - @_deprecate_positional_args - def fit( - self, - X: ArrayLike, - y: ArrayLike, - *, - sample_weight: Optional[ArrayLike] = None, - base_margin: Optional[ArrayLike] = None, - eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, - eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, - early_stopping_rounds: Optional[int] = None, - verbose: Optional[Union[bool, int]] = True, - xgb_model: Optional[Union[Booster, str, XGBModel]] = None, - sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, - base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, - feature_weights: Optional[ArrayLike] = None, - callbacks: Optional[Sequence[TrainingCallback]] = None - ) -> "XGBRFClassifier": - args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} - _check_rf_callback(early_stopping_rounds, callbacks) - super().fit(**args) - return self - - -@xgboost_model_doc( - "Implementation of the scikit-learn API for XGBoost regression.", - ['estimators', 'model', 'objective']) -class XGBRegressor(XGBModel, XGBRegressorBase): - # pylint: disable=missing-docstring - @_deprecate_positional_args - def __init__( - self, *, objective: _SklObjective = "reg:squarederror", **kwargs: Any - ) -> None: - super().__init__(objective=objective, **kwargs) - - -@xgboost_model_doc( - "scikit-learn API for XGBoost random forest regression.", - ['model', 'objective'], extra_parameters=''' - n_estimators : int - Number of trees in random forest to fit. 
-''') -class XGBRFRegressor(XGBRegressor): - # pylint: disable=missing-docstring - @_deprecate_positional_args - def __init__( - self, - *, - learning_rate: float = 1.0, - subsample: float = 0.8, - colsample_bynode: float = 0.8, - reg_lambda: float = 1e-5, - **kwargs: Any - ) -> None: super().__init__( learning_rate=learning_rate, subsample=subsample, colsample_bynode=colsample_bynode, reg_lambda=reg_lambda, - **kwargs + **kwargs, ) _check_rf_callback(self.early_stopping_rounds, self.callbacks) @@ -1689,7 +1663,82 @@ class XGBRFRegressor(XGBRegressor): sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, feature_weights: Optional[ArrayLike] = None, - callbacks: Optional[Sequence[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None, + ) -> "XGBRFClassifier": + args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} + _check_rf_callback(early_stopping_rounds, callbacks) + super().fit(**args) + return self + + +@xgboost_model_doc( + "Implementation of the scikit-learn API for XGBoost regression.", + ["estimators", "model", "objective"], +) +class XGBRegressor(XGBModel, XGBRegressorBase): + # pylint: disable=missing-docstring + @_deprecate_positional_args + def __init__( + self, *, objective: _SklObjective = "reg:squarederror", **kwargs: Any + ) -> None: + super().__init__(objective=objective, **kwargs) + + +@xgboost_model_doc( + "scikit-learn API for XGBoost random forest regression.", + ["model", "objective"], + extra_parameters=""" + n_estimators : int + Number of trees in random forest to fit. +""", +) +class XGBRFRegressor(XGBRegressor): + # pylint: disable=missing-docstring + @_deprecate_positional_args + def __init__( + self, + *, + learning_rate: float = 1.0, + subsample: float = 0.8, + colsample_bynode: float = 0.8, + reg_lambda: float = 1e-5, + **kwargs: Any, + ) -> None: + super().__init__( + learning_rate=learning_rate, + subsample=subsample, + colsample_bynode=colsample_bynode, + reg_lambda=reg_lambda, + **kwargs, + ) + _check_rf_callback(self.early_stopping_rounds, self.callbacks) + + def get_xgb_params(self) -> Dict[str, Any]: + params = super().get_xgb_params() + params["num_parallel_tree"] = self.n_estimators + return params + + def get_num_boosting_rounds(self) -> int: + return 1 + + # pylint: disable=unused-argument + @_deprecate_positional_args + def fit( + self, + X: ArrayLike, + y: ArrayLike, + *, + sample_weight: Optional[ArrayLike] = None, + base_margin: Optional[ArrayLike] = None, + eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None, + eval_metric: Optional[Union[str, Sequence[str], Metric]] = None, + early_stopping_rounds: Optional[int] = None, + verbose: Optional[Union[bool, int]] = True, + xgb_model: Optional[Union[Booster, str, XGBModel]] = None, + sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, + base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, + feature_weights: Optional[ArrayLike] = None, + callbacks: Optional[Sequence[TrainingCallback]] = None, ) -> "XGBRFRegressor": args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} _check_rf_callback(early_stopping_rounds, callbacks) @@ -1698,9 +1747,9 @@ class XGBRFRegressor(XGBRegressor): @xgboost_model_doc( - 'Implementation of the Scikit-Learn API for XGBoost Ranking.', - ['estimators', 'model'], - end_note=''' + "Implementation of the Scikit-Learn API for XGBoost Ranking.", + ["estimators", "model"], + end_note=""" .. 
note:: A custom objective function is currently not supported by XGBRanker. @@ -1737,7 +1786,8 @@ class XGBRFRegressor(XGBRegressor): then your group array should be ``[3, 4]``. Sometimes using query id (`qid`) instead of group can be more convenient. -''') +""", +) class XGBRanker(XGBModel, XGBRankerMixIn): # pylint: disable=missing-docstring,too-many-arguments,invalid-name @_deprecate_positional_args @@ -1768,7 +1818,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn): sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None, base_margin_eval_set: Optional[Sequence[ArrayLike]] = None, feature_weights: Optional[ArrayLike] = None, - callbacks: Optional[Sequence[TrainingCallback]] = None + callbacks: Optional[Sequence[TrainingCallback]] = None, ) -> "XGBRanker": # pylint: disable = attribute-defined-outside-init,arguments-differ """Fit gradient boosting ranker @@ -1853,56 +1903,65 @@ class XGBRanker(XGBModel, XGBRankerMixIn): Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead. """ # check if group information is provided - if group is None and qid is None: - raise ValueError("group or qid is required for ranking task") + with config_context(verbosity=self.verbosity): + if group is None and qid is None: + raise ValueError("group or qid is required for ranking task") - if eval_set is not None: - if eval_group is None and eval_qid is None: - raise ValueError( - "eval_group or eval_qid is required if eval_set is not None") - train_dmatrix, evals = _wrap_evaluation_matrices( - missing=self.missing, - X=X, - y=y, - group=group, - qid=qid, - sample_weight=sample_weight, - base_margin=base_margin, - feature_weights=feature_weights, - eval_set=eval_set, - sample_weight_eval_set=sample_weight_eval_set, - base_margin_eval_set=base_margin_eval_set, - eval_group=eval_group, - eval_qid=eval_qid, - create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), - enable_categorical=self.enable_categorical, - feature_types=self.feature_types, - ) - - evals_result: TrainingCallback.EvalsLog = {} - params = self.get_xgb_params() - - model, metric, params, early_stopping_rounds, callbacks = self._configure_fit( - xgb_model, eval_metric, params, early_stopping_rounds, callbacks - ) - if callable(metric): - raise ValueError( - 'Custom evaluation metric is not yet supported for XGBRanker.' 
+ if eval_set is not None: + if eval_group is None and eval_qid is None: + raise ValueError( + "eval_group or eval_qid is required if eval_set is not None" + ) + train_dmatrix, evals = _wrap_evaluation_matrices( + missing=self.missing, + X=X, + y=y, + group=group, + qid=qid, + sample_weight=sample_weight, + base_margin=base_margin, + feature_weights=feature_weights, + eval_set=eval_set, + sample_weight_eval_set=sample_weight_eval_set, + base_margin_eval_set=base_margin_eval_set, + eval_group=eval_group, + eval_qid=eval_qid, + create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), + enable_categorical=self.enable_categorical, + feature_types=self.feature_types, ) - self._Booster = train( - params, - train_dmatrix, - self.get_num_boosting_rounds(), - early_stopping_rounds=early_stopping_rounds, - evals=evals, - evals_result=evals_result, - custom_metric=metric, - verbose_eval=verbose, xgb_model=model, - callbacks=callbacks - ) + evals_result: TrainingCallback.EvalsLog = {} + params = self.get_xgb_params() - self.objective = params["objective"] + ( + model, + metric, + params, + early_stopping_rounds, + callbacks, + ) = self._configure_fit( + xgb_model, eval_metric, params, early_stopping_rounds, callbacks + ) + if callable(metric): + raise ValueError( + "Custom evaluation metric is not yet supported for XGBRanker." + ) - self._set_evaluation_result(evals_result) - return self + self._Booster = train( + params, + train_dmatrix, + self.get_num_boosting_rounds(), + early_stopping_rounds=early_stopping_rounds, + evals=evals, + evals_result=evals_result, + custom_metric=metric, + verbose_eval=verbose, + xgb_model=model, + callbacks=callbacks, + ) + + self.objective = params["objective"] + + self._set_evaluation_result(evals_result) + return self diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index cfc4b8598..8d08003cb 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -113,7 +113,9 @@ if __name__ == "__main__": run_formatter(path) for path in [ "python-package/xgboost/dask.py", + "python-package/xgboost/sklearn.py", "python-package/xgboost/spark", + "tests/python/test_config.py", "tests/python/test_spark/test_data.py", "tests/python-gpu/test_gpu_spark/test_data.py", "tests/ci_build/lint_python.py", diff --git a/tests/python/test_config.py b/tests/python/test_config.py index 87a544e9c..01b5c2d99 100644 --- a/tests/python/test_config.py +++ b/tests/python/test_config.py @@ -1,13 +1,15 @@ -# -*- coding: utf-8 -*- -import xgboost as xgb +import multiprocessing +from concurrent.futures import ThreadPoolExecutor + import pytest -import testing as tm + +import xgboost as xgb -@pytest.mark.parametrize('verbosity_level', [0, 1, 2, 3]) +@pytest.mark.parametrize("verbosity_level", [0, 1, 2, 3]) def test_global_config_verbosity(verbosity_level): def get_current_verbosity(): - return xgb.get_config()['verbosity'] + return xgb.get_config()["verbosity"] old_verbosity = get_current_verbosity() with xgb.config_context(verbosity=verbosity_level): @@ -16,13 +18,48 @@ def test_global_config_verbosity(verbosity_level): assert old_verbosity == get_current_verbosity() -@pytest.mark.parametrize('use_rmm', [False, True]) +@pytest.mark.parametrize("use_rmm", [False, True]) def test_global_config_use_rmm(use_rmm): def get_current_use_rmm_flag(): - return xgb.get_config()['use_rmm'] + return xgb.get_config()["use_rmm"] old_use_rmm_flag = get_current_use_rmm_flag() with xgb.config_context(use_rmm=use_rmm): new_use_rmm_flag = 
get_current_use_rmm_flag()
         assert new_use_rmm_flag == use_rmm
     assert old_use_rmm_flag == get_current_use_rmm_flag()
+
+
+def test_nested_config():
+    with xgb.config_context(verbosity=3):
+        assert xgb.get_config()["verbosity"] == 3
+        with xgb.config_context(verbosity=2):
+            assert xgb.get_config()["verbosity"] == 2
+            with xgb.config_context(verbosity=1):
+                assert xgb.get_config()["verbosity"] == 1
+            assert xgb.get_config()["verbosity"] == 2
+        assert xgb.get_config()["verbosity"] == 3
+
+    with xgb.config_context(verbosity=3):
+        assert xgb.get_config()["verbosity"] == 3
+        with xgb.config_context(verbosity=None):
+            assert xgb.get_config()["verbosity"] == 3  # None has no effect
+
+    verbosity = xgb.get_config()["verbosity"]
+    xgb.set_config(verbosity=2)
+    assert xgb.get_config()["verbosity"] == 2
+    with xgb.config_context(verbosity=3):
+        assert xgb.get_config()["verbosity"] == 3
+    xgb.set_config(verbosity=verbosity)  # reset
+
+
+def test_thread_safety():
+    n_threads = multiprocessing.cpu_count()
+    futures = []
+    with ThreadPoolExecutor(max_workers=n_threads) as executor:
+        for _ in range(256):
+            f = executor.submit(test_nested_config)
+            futures.append(f)
+
+        for f in futures:
+            f.result()
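
Taken together, the changes above make the global configuration context re-entrant (``set_config`` now drops ``None`` values, so nested contexts restore the previous value on exit) and route the scikit-learn estimators' ``verbosity`` parameter through ``config_context`` during ``fit``, ``predict``, and ``apply``. The snippet below is a minimal usage sketch of that behaviour, not part of the patch; the toy data and parameter values are illustrative only.

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(32, 4)
    y = np.array([0, 1] * 16)

    with xgb.config_context(verbosity=3):
        assert xgb.get_config()["verbosity"] == 3
        # The estimator's verbosity=0 is applied in a nested config context
        # for the duration of fit(); the outer value (3) is restored afterwards.
        xgb.XGBClassifier(n_estimators=2, verbosity=0).fit(X, y)
        assert xgb.get_config()["verbosity"] == 3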