diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index ffcaa9e77..f134c0399 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -1219,7 +1219,9 @@ class Booster(object):
         elif isinstance(params, STRING_TYPES) and value is not None:
             params = [(params, value)]
         for key, val in params:
-            _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val))))
+            if val is not None:
+                _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key),
+                                                   c_str(str(val))))
 
     def update(self, dtrain, iteration, fobj=None):
         """Update for one iteration, with objective function calculated
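The guard added above means `Booster.set_param` skips parameters whose value is None instead of serialising them as the string "None"; this is what lets the scikit-learn wrappers below default their hyperparameters to None and leave the real defaults to the native library. A minimal sketch of the effect, not part of the patch, with made-up parameter values:

import xgboost as xgb

# 'max_depth' is skipped entirely, so XGBoost keeps its built-in default;
# only 'eta' is forwarded to the native booster.
booster = xgb.Booster(params={'max_depth': None, 'eta': 0.3})
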
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index b1e0fafdc..3b5c8ff59 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -30,7 +30,7 @@ from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
 from .core import DMatrix, Booster, _expect
 from .training import train as worker_train
 from .tracker import RabitTracker
-from .sklearn import XGBModel, XGBClassifierBase
+from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
 
 # Current status is considered as initial support, many features are
 # not properly supported yet.
@@ -580,13 +580,10 @@ class DaskScikitLearnBase(XGBModel):
     def client(self, clt):
         self._client = clt
 
-
+@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
+                   ['estimators', 'model'])
 class DaskXGBRegressor(DaskScikitLearnBase):
     # pylint: disable=missing-docstring
-    __doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
-               'regression. \n\n') + '\n'.join(
-                   XGBModel.__doc__.split('\n')[2:])
-
     def fit(self,
             X,
             y,
@@ -616,12 +613,13 @@ class DaskXGBRegressor(DaskScikitLearnBase):
         return pred_probs
 
 
+@xgboost_model_doc(
+    'Implementation of the scikit-learn API for XGBoost classification.',
+    ['estimators', 'model']
+)
 class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     # pylint: disable=missing-docstring
     _client = None
-    __doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
-               'classification.\n\n') + '\n'.join(
-                   XGBModel.__doc__.split('\n')[2:])
 
     def fit(self,
             X,
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index f46601d8e..4e682fcb8 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1,14 +1,15 @@
 # coding: utf-8
 # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
 """Scikit-Learn Wrapper interface for XGBoost."""
+import copy
 import warnings
 import json
 import numpy as np
 from .core import Booster, DMatrix, XGBoostError
 from .training import train
 
-# Do not use class names on scikit-learn directly.
-# Re-define the classes on .compat to guarantee the behavior without scikit-learn
+# Do not use class names on scikit-learn directly.  Re-define the classes on
+# .compat to guarantee the behavior without scikit-learn
 from .compat import (SKLEARN_INSTALLED, XGBModelBase,
                      XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
 
@@ -48,18 +49,17 @@ def _objective_decorator(func):
         return inner
 
 
-class XGBModel(XGBModelBase):
-    # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
-    """Implementation of the Scikit-Learn API for XGBoost.
+__estimator_doc = '''
+    n_estimators : int
+        Number of gradient boosted trees.  Equivalent to number of boosting
+        rounds.
+'''
 
-    Parameters
-    ----------
+__model_doc = '''
     max_depth : int
         Maximum tree depth for base learners.
     learning_rate : float
         Boosting learning rate (xgb's "eta")
-    n_estimators : int
-        Number of trees to fit.
     verbosity : int
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
     objective : string or callable
@@ -75,7 +75,8 @@ class XGBModel(XGBModelBase):
     n_jobs : int
        Number of parallel threads used to run xgboost.
     gamma : float
-        Minimum loss reduction required to make a further partition on a leaf node of the tree.
+        Minimum loss reduction required to make a further partition on a leaf
+        node of the tree.
     min_child_weight : int
         Minimum sum of instance weight(hessian) needed in a child.
     max_delta_step : int
@@ -112,17 +113,21 @@ class XGBModel(XGBModelBase):
     importance_type: string, default "gain"
         The feature importance type for the feature_importances\\_ property:
         either "gain", "weight", "cover", "total_gain" or "total_cover".
+
     \\*\\*kwargs : dict, optional
-        Keyword arguments for XGBoost Booster object.  Full documentation of parameters can
-        be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
-        Attempting to set a parameter via the constructor args and \\*\\*kwargs dict simultaneously
-        will result in a TypeError.
+        Keyword arguments for XGBoost Booster object.  Full documentation of
+        parameters can be found here:
+        https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
+        Attempting to set a parameter via the constructor args and \\*\\*kwargs
+        dict simultaneously will result in a TypeError.
 
         .. note:: \\*\\*kwargs unsupported by scikit-learn
 
-            \\*\\*kwargs is unsupported by scikit-learn.  We do not guarantee that parameters
-            passed via this argument will interact properly with scikit-learn.
+            \\*\\*kwargs is unsupported by scikit-learn.  We do not guarantee
+            that parameters passed via this argument will interact properly
+            with scikit-learn.
+'''
 
+__custom_obj_note = '''
     Note
     ----
     A custom objective function can be provided for the ``objective``
@@ -138,25 +143,72 @@ class XGBModel(XGBModelBase):
             The value of the gradient for each sample point.
         hess: array_like of shape [n_samples]
             The value of the second derivative for each sample point
+'''
-
-    """
-    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, objective="reg:squarederror",
-                 booster='gbtree', tree_method='auto', n_jobs=1, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=1,
-                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, num_parallel_tree=1,
-                 importance_type="gain", **kwargs):
+
+def xgboost_model_doc(header, items, extra_parameters=None, end_note=None):
+    '''Obtain documentation for Scikit-Learn wrappers.
+
+    Parameters
+    ----------
+    header: str
+        An introduction to the class.
+    items : list
+        A list of common doc items.  Available items are:
+          - estimators: the meaning of n_estimators
+          - model: all the other parameters
+          - objective: note for customized objective
+    extra_parameters: str
+        Documentation for class-specific parameters, placed at the beginning.
+    end_note: str
+        Extra notes appended at the end.
+'''
+    def get_doc(item):
+        '''Return selected item'''
+        __doc = {'estimators': __estimator_doc,
+                 'model': __model_doc,
+                 'objective': __custom_obj_note}
+        return __doc[item]
+
+    def adddoc(cls):
+        doc = ['''
+Parameters
+----------
+''']
+        if extra_parameters:
+            doc.append(extra_parameters)
+        doc.extend([get_doc(i) for i in items])
+        if end_note:
+            doc.append(end_note)
+        full_doc = [header + '\n\n']
+        full_doc.extend(doc)
+        cls.__doc__ = ''.join(full_doc)
+        return cls
+    return adddoc
+
+
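A quick sketch of how the decorator above is meant to be used (not part of the patch; the toy class and its extra parameter are invented, and the import assumes this patched module is installed):

from xgboost.sklearn import xgboost_model_doc

@xgboost_model_doc(
    'A toy wrapper, used only to illustrate the decorator.',
    ['estimators', 'model'],
    extra_parameters='''
    toy_param : int
        A made-up, class-specific parameter documented first.
''')
class ToyWrapper:
    pass

# The composed docstring starts with the header, followed by a 'Parameters'
# section built from extra_parameters plus the shared items.
print(ToyWrapper.__doc__.splitlines()[0])
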
+@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
+                   ['estimators', 'model', 'objective'])
+class XGBModel(XGBModelBase):
+    # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name, missing-docstring
+    def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
+                 verbosity=None, objective=None, booster=None,
+                 tree_method=None, n_jobs=None, gamma=None,
+                 min_child_weight=None, max_delta_step=None, subsample=None,
+                 colsample_bytree=None, colsample_bylevel=None,
+                 colsample_bynode=None, reg_alpha=None, reg_lambda=None,
+                 scale_pos_weight=None, base_score=None, random_state=None,
+                 missing=None, num_parallel_tree=None, importance_type="gain",
+                 gpu_id=None, **kwargs):
         if not SKLEARN_INSTALLED:
             raise XGBoostError(
                 'sklearn needs to be installed in order to use this module')
+        self.n_estimators = n_estimators
+        self.objective = objective
+
         self.max_depth = max_depth
         self.learning_rate = learning_rate
-        self.n_estimators = n_estimators
         self.verbosity = verbosity
-        self.objective = objective
         self.booster = booster
         self.tree_method = tree_method
         self.gamma = gamma
@@ -176,6 +228,7 @@ class XGBModel(XGBModelBase):
         self._Booster = None
         self.random_state = random_state
         self.n_jobs = n_jobs
+        self.gpu_id = gpu_id
         self.importance_type = importance_type
 
     def __setstate__(self, state):
@@ -201,18 +254,22 @@ class XGBModel(XGBModelBase):
         return self._Booster
 
     def set_params(self, **params):
-        """Set the parameters of this estimator.
-        Modification of the sklearn method to allow unknown kwargs. This allows using
-        the full range of xgboost parameters that are not defined as member variables
-        in sklearn grid search.
+        """Set the parameters of this estimator.  Modification of the sklearn method to
+        allow unknown kwargs.  This allows using the full range of xgboost
+        parameters that are not defined as member variables in sklearn grid
+        search.
+
         Returns
         -------
         self
+
         """
         if not params:
             # Simple optimization to gain speed (inspect is slow)
             return self
+        # This folds kwargs into the regular parameters, enabling `get_params`
+        # to obtain parameters that were passed as keyword arguments.
        for key, value in params.items():
             if hasattr(self, key):
                 setattr(self, key, value)
@@ -221,16 +278,26 @@ class XGBModel(XGBModelBase):
 
         return self
 
-    def get_params(self, deep=False):
+    def get_params(self, deep=True):
+        # pylint: disable=attribute-defined-outside-init
         """Get parameters."""
-        params = super(XGBModel, self).get_params(deep=deep)
+        # Based on: https://stackoverflow.com/questions/59248211
+        # The basic flow in `get_params` is:
+        # 0. Return parameters in subclass first, by using inspect.
+        # 1. Return parameters in `XGBModel` (the base class).
+        # 2. Return whatever is in `**kwargs`.
+        # 3. Merge them.
+        params = super().get_params(deep)
+        if hasattr(self, '__copy__'):
+            warnings.warn('Calling __copy__ on Scikit-Learn wrapper, ' +
+                          'which may disable data cache and result in ' +
+                          'lower performance.')
+        cp = copy.copy(self)
+        cp.__class__ = cp.__class__.__bases__[0]
+        params.update(cp.__class__.get_params(cp, deep))
         # if kwargs is a dict, update params accordingly
         if isinstance(self.kwargs, dict):
             params.update(self.kwargs)
-        if params['missing'] is np.nan:
-            params['missing'] = None  # sklearn doesn't handle nan. see #4725
-        if not params.get('eval_metric', True):
-            del params['eval_metric']  # don't give as None param to Booster
         if isinstance(params['random_state'], np.random.RandomState):
             params['random_state'] = params['random_state'].randint(
                 np.iinfo(np.int32).max)
@@ -354,10 +421,11 @@ class XGBModel(XGBModelBase):
                 [xgb.callback.reset_learning_rate(custom_rates)]
         """
-        trainDmatrix = DMatrix(data=X, label=y, weight=sample_weight,
-                               base_margin=base_margin,
-                               missing=self.missing,
-                               nthread=self.n_jobs)
+        train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
+                                base_margin=base_margin,
+                                missing=self.missing,
+                                nthread=self.n_jobs)
+
         evals_result = {}
 
         if eval_set is not None:
@@ -389,7 +457,7 @@ class XGBModel(XGBModelBase):
             else:
                 params.update({'eval_metric': eval_metric})
 
-        self._Booster = train(params, trainDmatrix,
+        self._Booster = train(params, train_dmatrix,
                               self.get_num_boosting_rounds(), evals=evals,
                               early_stopping_rounds=early_stopping_rounds,
                               evals_result=evals_result, obj=obj, feval=feval,
@@ -419,13 +487,6 @@ class XGBModel(XGBModelBase):
         If you want to run prediction using multiple thread, call
         ``xgb.copy()`` to make copies of model object and then call
         ``predict()``.
 
-        .. note:: Using ``predict()`` with DART booster
-
-            If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
-            some of the trees will be evaluated. This will produce incorrect results if ``data`` is
-            not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
-            a nonzero value, e.g.
-
         .. code-block:: python
 
             preds = bst.predict(dtest, ntree_limit=num_round)
@@ -539,7 +600,8 @@ class XGBModel(XGBModelBase):
         feature_importances_ : array of shape ``[n_features]``
 
         """
-        if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
+        if getattr(self, 'booster', None) is not None and self.booster not in {
+                'gbtree', 'dart'}:
             raise AttributeError('Feature importance is not defined for Booster type {}'
                                  .format(self.booster))
         b = self.get_booster()
@@ -555,9 +617,9 @@ class XGBModel(XGBModelBase):
 
         .. note:: Coefficients are defined only for linear learners
 
-            Coefficients are only defined when the linear model is chosen as base
-            learner (`booster=gblinear`). It is not defined for other base learner types, such
-            as tree learners (`booster=gbtree`).
+            Coefficients are only defined when the linear model is chosen as
+            base learner (`booster=gblinear`). It is not defined for other base
+            learner types, such as tree learners (`booster=gbtree`).

         Returns
         -------
@@ -599,33 +661,13 @@ class XGBModel(XGBModelBase):
         return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])
 
 
+@xgboost_model_doc(
+    "Implementation of the scikit-learn API for XGBoost classification.",
+    ['model', 'objective'])
 class XGBClassifier(XGBModel, XGBClassifierBase):
     # pylint: disable=missing-docstring,too-many-arguments,invalid-name,too-many-instance-attributes
-    __doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
-
-    def __init__(self, max_depth=3, learning_rate=0.1,
-                 n_estimators=100, verbosity=1,
-                 objective="binary:logistic", booster='gbtree',
-                 tree_method='auto', n_jobs=1, gpu_id=-1, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=1,
-                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, **kwargs):
-        super(XGBClassifier, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster=booster, tree_method=tree_method,
-            n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
-            min_child_weight=min_child_weight,
-            max_delta_step=max_delta_step, subsample=subsample,
-            colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-            scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+    def __init__(self, objective="binary:logistic", **kwargs):
+        super().__init__(objective=objective, **kwargs)
 
     def fit(self, X, y, sample_weight=None, base_margin=None,
             eval_set=None, eval_metric=None,
@@ -647,8 +689,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         obj = None
 
         if self.n_classes_ > 2:
-            # Switch to using a multiclass objective in the underlying XGB instance
-            xgb_options["objective"] = "multi:softprob"
+            # Switch to using a multiclass objective in the underlying
+            # XGB instance
+            xgb_options['objective'] = 'multi:softprob'
             xgb_options['num_class'] = self.n_classes_
 
         feval = eval_metric if callable(eval_metric) else None
@@ -665,7 +708,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             if sample_weight_eval_set is None:
                 sample_weight_eval_set = [None] * len(eval_set)
             evals = list(
-                DMatrix(eval_set[i][0], label=self._le.transform(eval_set[i][1]),
+                DMatrix(eval_set[i][0],
+                        label=self._le.transform(eval_set[i][1]),
                         missing=self.missing, weight=sample_weight_eval_set[i],
                         nthread=self.n_jobs)
                 for i in range(len(eval_set))
@@ -686,8 +730,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                                 base_margin=base_margin,
                                 missing=self.missing, nthread=self.n_jobs)
 
-        self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
-                              evals=evals, early_stopping_rounds=early_stopping_rounds,
+        self._Booster = train(xgb_options, train_dmatrix,
+                              self.get_num_boosting_rounds(),
+                              evals=evals,
+                              early_stopping_rounds=early_stopping_rounds,
                               evals_result=evals_result, obj=obj, feval=feval,
                               verbose_eval=verbose, xgb_model=xgb_model,
                               callbacks=callbacks)
@@ -696,7 +742,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         if evals_result:
             for val in evals_result.items():
                 evals_result_key = list(val[1].keys())[0]
-                evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
+                evals_result[val[0]][
+                    evals_result_key] = val[1][evals_result_key]
             self.evals_result_ = evals_result
 
         if early_stopping_rounds is not None:
@@ -706,8 +753,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
         return self
 
-    fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
-                                               'Fit gradient boosting classifier', 1)
+    fit.__doc__ = XGBModel.fit.__doc__.replace(
+        'Fit gradient boosting model',
+        'Fit gradient boosting classifier', 1)
 
     def predict(self, data, output_margin=False, ntree_limit=None,
                 validate_features=True, base_margin=None):
@@ -717,15 +765,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         .. note:: This function is not thread safe.
 
           For each booster object, predict can only be called from one thread.
-          If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
-          of model object and then call ``predict()``.
-
-          .. note:: Using ``predict()`` with DART booster
-
-            If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
-            some of the trees will be evaluated. This will produce incorrect results if ``data`` is
-            not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
-            a nonzero value, e.g.
+          If you want to run prediction using multiple threads, call
+          ``xgb.copy()`` to make copies of model object and then call
+          ``predict()``.
 
           .. code-block:: python
 
@@ -738,11 +780,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         output_margin : bool
             Whether to output the raw untransformed margin value.
         ntree_limit : int
-            Limit number of trees in the prediction; defaults to best_ntree_limit if defined
-            (i.e. it has been trained with early stopping), otherwise 0 (use all trees).
+            Limit number of trees in the prediction; defaults to
+            best_ntree_limit if defined (i.e. it has been trained with early
+            stopping), otherwise 0 (use all trees).
         validate_features : bool
-            When this is True, validate that the Booster's and data's feature_names are identical.
-            Otherwise, it is assumed that the feature_names are the same.
+            When this is True, validate that the Booster's and data's
+            feature_names are identical.  Otherwise, it is assumed that the
+            feature_names are the same.
+
         Returns
         -------
         prediction : numpy array
@@ -773,9 +818,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
         .. note:: This function is not thread safe
 
-            For each booster object, predict can only be called from one thread.
-            If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
-            of model object and then call predict
+            For each booster object, predict can only be called from one
+            thread.  If you want to run prediction using multiple threads, call
+            ``xgb.copy()`` to make copies of model object and then call predict
 
         Parameters
         ----------
@@ -849,30 +894,26 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         return evals_result
 
 
+@xgboost_model_doc(
+    "scikit-learn API for XGBoost random forest classification.",
+    ['model', 'objective'],
+    extra_parameters='''
+    n_estimators : int
+        Number of trees in random forest to fit.
+''')
 class XGBRFClassifier(XGBClassifier):
     # pylint: disable=missing-docstring
-    __doc__ = "scikit-learn API for XGBoost random forest classification.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
-
-    def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
-                 verbosity=1, objective="binary:logistic", n_jobs=1,
-                 gpu_id=-1, gamma=0, min_child_weight=1, max_delta_step=0,
-                 subsample=0.8, colsample_bytree=1, colsample_bylevel=1,
-                 colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
-                 scale_pos_weight=1, base_score=0.5, random_state=0,
-                 missing=None, **kwargs):
-        super(XGBRFClassifier, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster='gbtree', n_jobs=n_jobs,
-            gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
-            max_delta_step=max_delta_step,
-            subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
-            reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+    def __init__(self,
+                 learning_rate=1,
+                 subsample=0.8,
+                 colsample_bynode=0.8,
+                 reg_lambda=1e-5,
+                 **kwargs):
+        super().__init__(learning_rate=learning_rate,
+                         subsample=subsample,
+                         colsample_bynode=colsample_bynode,
+                         reg_lambda=reg_lambda,
+                         **kwargs)
 
     def get_xgb_params(self):
         params = super(XGBRFClassifier, self).get_xgb_params()
@@ -883,37 +924,25 @@ class XGBRFClassifier(XGBClassifier):
         return 1
 
 
+@xgboost_model_doc(
+    "Implementation of the scikit-learn API for XGBoost regression.",
+    ['estimators', 'model', 'objective'])
 class XGBRegressor(XGBModel, XGBRegressorBase):
     # pylint: disable=missing-docstring
-    __doc__ = "Implementation of the scikit-learn API for XGBoost regression.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
+    def __init__(self, objective="reg:squarederror", **kwargs):
+        super().__init__(objective=objective, **kwargs)
 
 
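Because the subclass constructors now just forward keyword arguments, the random-forest wrappers only pin down their forest-specific defaults; every other hyperparameter stays None and is resolved by the native booster. A small sketch of what that looks like from the caller's side (not part of the patch; assumes the patched package is installed):

import xgboost as xgb

clf = xgb.XGBRFClassifier(n_estimators=50)
params = clf.get_params()
# Forest-specific defaults pinned by the constructor:
print(params['learning_rate'], params['subsample'], params['colsample_bynode'])
# Anything not set explicitly stays None and is left to XGBoost itself:
print(params['max_depth'], params['min_child_weight'])
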
+@xgboost_model_doc(
+    "scikit-learn API for XGBoost random forest regression.",
+    ['model', 'objective'])
 class XGBRFRegressor(XGBRegressor):
     # pylint: disable=missing-docstring
-    __doc__ = "scikit-learn API for XGBoost random forest regression.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
-
-    def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
-                 verbosity=1, objective="reg:squarederror", n_jobs=1,
-                 gpu_id=-1, gamma=0, min_child_weight=1,
-                 max_delta_step=0, subsample=0.8, colsample_bytree=1,
-                 colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0,
-                 reg_lambda=1e-5, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, **kwargs):
-        super(XGBRFRegressor, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster='gbtree', n_jobs=n_jobs,
-            gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
-            max_delta_step=max_delta_step, subsample=subsample,
-            colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-            scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+    def __init__(self, learning_rate=1, subsample=0.8, colsample_bynode=0.8,
+                 reg_lambda=1e-5, **kwargs):
+        super().__init__(learning_rate=learning_rate, subsample=subsample,
+                         colsample_bynode=colsample_bynode,
+                         reg_lambda=reg_lambda,
+                         **kwargs)
 
     def get_xgb_params(self):
         params = super(XGBRFRegressor, self).get_xgb_params()
@@ -924,71 +953,10 @@ class XGBRFRegressor(XGBRegressor):
         return 1
 
 
-class XGBRanker(XGBModel):
-    # pylint: disable=missing-docstring,too-many-arguments,invalid-name
-    """Implementation of the Scikit-Learn API for XGBoost Ranking.
-
-    Parameters
-    ----------
-    max_depth : int
-        Maximum tree depth for base learners.
-    learning_rate : float
-        Boosting learning rate (xgb's "eta")
-    n_estimators : int
-        Number of boosted trees to fit.
-    verbosity : int
-        The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
-    objective : string
-        Specify the learning task and the corresponding learning objective.
-        The objective name must start with "rank:".
-    booster: string
-        Specify which booster to use: gbtree, gblinear or dart.
-    n_jobs : int
-        Number of parallel threads used to run xgboost.
-    gamma : float
-        Minimum loss reduction required to make a further partition on a leaf node of the tree.
-    min_child_weight : int
-        Minimum sum of instance weight(hessian) needed in a child.
-    max_delta_step : int
-        Maximum delta step we allow each tree's weight estimation to be.
-    subsample : float
-        Subsample ratio of the training instance.
-    colsample_bytree : float
-        Subsample ratio of columns when constructing each tree.
-    colsample_bylevel : float
-        Subsample ratio of columns for each level.
-    colsample_bynode : float
-        Subsample ratio of columns for each split.
-    reg_alpha : float (xgb's alpha)
-        L1 regularization term on weights
-    reg_lambda : float (xgb's lambda)
-        L2 regularization term on weights
-    scale_pos_weight : float
-        Balancing of positive and negative weights.
-    base_score:
-        The initial prediction score of all instances, global bias.
-    random_state : int
-        Random number seed.
-
-        .. note::
-
-           Using gblinear booster with shotgun updater is nondeterministic as
-           it uses Hogwild algorithm.
-
-    missing : float, optional
-        Value in the data which needs to be present as a missing value. If
-        None, defaults to np.nan.
-    \\*\\*kwargs : dict, optional
-        Keyword arguments for XGBoost Booster object.  Full documentation of parameters can
-        be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
-        Attempting to set a parameter via the constructor args and \\*\\*kwargs dict
-        simultaneously will result in a TypeError.
-
-        .. note:: \\*\\*kwargs unsupported by scikit-learn
-
-            \\*\\*kwargs is unsupported by scikit-learn.  We do not guarantee that parameters
-            passed via this argument will interact properly with scikit-learn.
-
+@xgboost_model_doc(
+    'Implementation of the Scikit-Learn API for XGBoost Ranking.',
+    ['estimators', 'model'],
+    end_note='''
         Note
         ----
         A custom objective function is currently not supported by XGBRanker.
@@ -998,9 +966,9 @@ class XGBRanker(XGBModel):
         ----
         Query group information is required for ranking tasks.
 
-        Before fitting the model, your data need to be sorted by query group. When
-        fitting the model, you need to provide an additional array that
-        contains the size of each query group.
+        Before fitting the model, your data need to be sorted by query
+        group.  When fitting the model, you need to provide an additional array
+        that contains the size of each query group.
 
         For example, if your original data look like:
@@ -1023,29 +991,11 @@ class XGBRanker(XGBModel):
         +-------+-----------+---------------+
 
         then your group array should be ``[3, 4]``.
-
-    """
-
-    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, objective="rank:pairwise", booster='gbtree',
-                 tree_method='auto', n_jobs=-1, gpu_id=-1, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=1,
-                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, **kwargs):
-
-        super(XGBRanker, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster=booster, tree_method=tree_method,
-            n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
-            min_child_weight=min_child_weight, max_delta_step=max_delta_step,
-            subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
-            reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+''')
+class XGBRanker(XGBModel):
+    # pylint: disable=missing-docstring,too-many-arguments,invalid-name
+    def __init__(self, objective='rank:pairwise', **kwargs):
+        super().__init__(objective=objective, **kwargs)
         if callable(self.objective):
             raise ValueError(
                 "custom objective function not supported by XGBRanker")
@@ -1058,8 +1008,7 @@ class XGBRanker(XGBModel):
             early_stopping_rounds=None, verbose=False, xgb_model=None,
             callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
-        """
-        Fit gradient boosting ranker
+        """Fit gradient boosting ranker
 
         Parameters
         ----------
@@ -1068,17 +1017,17 @@ class XGBRanker(XGBModel):
         y : array_like
             Labels
         group : array_like
-            Size of each query group of training data. Should have as many elements as
-            the query groups in the training data
+            Size of each query group of training data.  Should have as many
+            elements as the query groups in the training data
         sample_weight : array_like
             Query group weights
 
             .. note:: Weights are per-group for ranking tasks
 
-                In ranking task, one weight is assigned to each query group (not each
-                data point). This is because we only care about the relative ordering of
-                data points within each group, so it doesn't make sense to assign
-                weights to individual data points.
+                In ranking task, one weight is assigned to each query group
+                (not each data point). This is because we only care about the
+                relative ordering of data points within each group, so it
+                doesn't make sense to assign weights to individual data points.
 
         base_margin : array_like
             Global bias for each instance.
@@ -1112,24 +1061,26 @@ class XGBRanker(XGBModel):
             The method returns the model from the last iteration (not the best
            one).  If there's more than one item in **eval_set**, the last entry
            will be used for early stopping.
-            If there's more than one metric in **eval_metric**, the last metric will be
-            used for early stopping.
-            If early stopping occurs, the model will have three additional fields:
-            ``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
+            If there's more than one metric in **eval_metric**, the last metric
+            will be used for early stopping.
+            If early stopping occurs, the model will have three additional
+            fields: ``clf.best_score``, ``clf.best_iteration`` and
+            ``clf.best_ntree_limit``.
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
         xgb_model : str
-            file name of stored XGBoost model or 'Booster' instance XGBoost model to be
-            loaded before training (allows training continuation).
+            file name of stored XGBoost model or 'Booster' instance XGBoost
+            model to be loaded before training (allows training continuation).
         callbacks : list of callback functions
-            List of callback functions that are applied at end of each iteration.
-            It is possible to use predefined callbacks by using :ref:`callback_api`.
-            Example:
+            List of callback functions that are applied at end of each
+            iteration.  It is possible to use predefined callbacks by using
+            :ref:`callback_api`.  Example:
 
             .. code-block:: python
 
                 [xgb.callback.reset_learning_rate(custom_rates)]
+
         """
         # check if group information is provided
         if group is None:
@@ -1137,11 +1088,14 @@ class XGBRanker(XGBModel):
 
         if eval_set is not None:
             if eval_group is None:
-                raise ValueError("eval_group is required if eval_set is not None")
+                raise ValueError(
+                    "eval_group is required if eval_set is not None")
             if len(eval_group) != len(eval_set):
-                raise ValueError("length of eval_group should match that of eval_set")
+                raise ValueError(
+                    "length of eval_group should match that of eval_set")
             if any(group is None for group in eval_group):
-                raise ValueError("group is required for all eval datasets for ranking task")
+                raise ValueError(
+                    "group is required for all eval datasets for ranking task")
 
         def _dmat_init(group, **params):
             ret = DMatrix(**params)
@@ -1158,9 +1112,13 @@ class XGBRanker(XGBModel):
         if eval_set is not None:
             if sample_weight_eval_set is None:
                 sample_weight_eval_set = [None] * len(eval_set)
-            evals = [_dmat_init(eval_group[i], data=eval_set[i][0], label=eval_set[i][1],
-                                missing=self.missing, weight=sample_weight_eval_set[i],
-                                nthread=self.n_jobs) for i in range(len(eval_set))]
+            evals = [_dmat_init(eval_group[i],
+                                data=eval_set[i][0],
+                                label=eval_set[i][1],
+                                missing=self.missing,
+                                weight=sample_weight_eval_set[i],
+                                nthread=self.n_jobs)
+                     for i in range(len(eval_set))]
             nevals = len(evals)
             eval_names = ["eval_{}".format(i) for i in range(nevals)]
             evals = list(zip(evals, eval_names))
@@ -1172,13 +1130,14 @@ class XGBRanker(XGBModel):
         feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
             if callable(eval_metric):
-                raise ValueError('Custom evaluation metric is not yet supported' +
-                                 'for XGBRanker.')
+                raise ValueError(
+                    'Custom evaluation metric is not yet supported for XGBRanker.')
             params.update({'eval_metric': eval_metric})
 
         self._Booster = train(params, train_dmatrix, self.n_estimators,
-                              early_stopping_rounds=early_stopping_rounds, evals=evals,
+                              early_stopping_rounds=early_stopping_rounds,
+                              evals=evals,
                               evals_result=evals_result, feval=feval,
                               verbose_eval=verbose, xgb_model=xgb_model,
                               callbacks=callbacks)
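To round off the XGBRanker changes, here is a usage sketch of the reworked `fit` with query-group sizes, mirroring the docstring above (not part of the patch; the data are invented, and rows must already be sorted by query group):

import numpy as np
import xgboost as xgb

X = np.random.RandomState(0).randn(7, 3)
y = np.array([1, 0, 2, 1, 0, 0, 1])   # relevance labels
group = [3, 4]                        # first 3 rows form query 1, next 4 form query 2

ranker = xgb.XGBRanker(n_estimators=10)   # objective defaults to 'rank:pairwise'
ranker.fit(X, y, group=group)
scores = ranker.predict(X)
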
diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index 142b98e71..ca95f46cb 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -3,7 +3,8 @@
 # pylint: disable=too-many-branches, too-many-statements
 """Training Library containing training routines."""
 import numpy as np
-from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
+from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv
+from .core import EarlyStopException
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
 from . import rabit
 from . import callback
@@ -37,10 +38,11 @@ def _train_internal(params, dtrain,
 
     _params = dict(params) if isinstance(params, list) else params
 
-    if 'num_parallel_tree' in _params:
+    if 'num_parallel_tree' in _params and _params[
+            'num_parallel_tree'] is not None:
         num_parallel_tree = _params['num_parallel_tree']
         nboost //= num_parallel_tree
-    if 'num_class' in _params:
+    if 'num_class' in _params and _params['num_class'] is not None:
         nboost //= _params['num_class']
 
     # Distributed code: Load the checkpoint from rabit.
diff --git a/tests/python/test_early_stopping.py b/tests/python/test_early_stopping.py
index e709ecc24..1f8874bf5 100644
--- a/tests/python/test_early_stopping.py
+++ b/tests/python/test_early_stopping.py
@@ -22,17 +22,17 @@ class TestEarlyStopping(unittest.TestCase):
         y = digits['target']
         X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
 
-        clf1 = xgb.XGBClassifier()
+        clf1 = xgb.XGBClassifier(learning_rate=0.1)
         clf1.fit(X_train, y_train, early_stopping_rounds=5, eval_metric="auc",
                  eval_set=[(X_test, y_test)])
-        clf2 = xgb.XGBClassifier()
+        clf2 = xgb.XGBClassifier(learning_rate=0.1)
         clf2.fit(X_train, y_train, early_stopping_rounds=4, eval_metric="auc",
                  eval_set=[(X_test, y_test)])
         # should be the same
         assert clf1.best_score == clf2.best_score
         assert clf1.best_score != 1
         # check overfit
-        clf3 = xgb.XGBClassifier()
+        clf3 = xgb.XGBClassifier(learning_rate=0.1)
         clf3.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
                  eval_set=[(X_test, y_test)])
         assert clf3.best_score == 1
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 19195d857..3c0092282 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -6,6 +6,7 @@ import os
 import shutil
 import pytest
 import unittest
+import json
 
 rng = np.random.RandomState(1994)
 
@@ -117,9 +118,10 @@ def test_feature_importances_weight():
     digits = load_digits(2)
     y = digits['target']
     X = digits['data']
-    xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
-
+    xgb_model = xgb.XGBClassifier(random_state=0,
+                                  tree_method="exact",
+                                  learning_rate=0.1,
+                                  importance_type="weight").fit(X, y)
     exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00833333, 0.,
                     0., 0., 0., 0., 0., 0., 0., 0.025, 0.14166667, 0., 0., 0.,
                     0., 0., 0., 0.00833333, 0.25833333, 0., 0., 0., 0.,
@@ -134,12 +136,16 @@ def test_feature_importances_weight():
     import pandas as pd
     y = pd.Series(digits['target'])
     X = pd.DataFrame(digits['data'])
-    xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
+    xgb_model = xgb.XGBClassifier(random_state=0,
+                                  tree_method="exact",
+                                  learning_rate=0.1,
+                                  importance_type="weight").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
-    xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
+    xgb_model = xgb.XGBClassifier(random_state=0,
+                                  tree_method="exact",
+                                  learning_rate=0.1,
+                                  importance_type="weight").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
 
@@ -151,7 +157,9 @@ def test_feature_importances_gain():
     y = digits['target']
     X = digits['data']
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
+        random_state=0, tree_method="exact",
+        learning_rate=0.1,
+        importance_type="gain").fit(X, y)
 
     exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00326159, 0.,
                     0., 0., 0., 0., 0., 0., 0.,
@@ -169,11 +177,15 @@ def test_feature_importances_gain():
     y = pd.Series(digits['target'])
     X = pd.DataFrame(digits['data'])
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
+        random_state=0, tree_method="exact",
+        learning_rate=0.1,
+        importance_type="gain").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
+        random_state=0, tree_method="exact",
+        learning_rate=0.1,
+        importance_type="gain").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
 
@@ -191,6 +203,10 @@ def test_num_parallel_tree():
     dump = bst.get_booster().get_dump(dump_format='json')
     assert len(dump) == 4
 
+    config = json.loads(bst.get_booster().save_config())
+    assert int(config['learner']['gradient_booster']['gbtree_train_param'][
+        'num_parallel_tree']) == 4
+
 
 def test_boston_housing_regression():
     from sklearn.metrics import mean_squared_error
@@ -244,7 +260,7 @@ def test_parameter_tuning():
     boston = load_boston()
     y = boston['target']
     X = boston['data']
-    xgb_model = xgb.XGBRegressor()
+    xgb_model = xgb.XGBRegressor(learning_rate=0.1)
     clf = GridSearchCV(xgb_model, {'max_depth': [2, 4, 6],
                                    'n_estimators': [50, 100, 200]},
                        cv=3, verbose=1, iid=True)
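Finally, a short sketch of the workflow the updated tests exercise: with hyperparameters defaulting to None, only explicitly set values are passed to the native booster, while `get_params()` still exposes everything scikit-learn's grid search needs (not part of the patch; the synthetic data and grid are invented):

import numpy as np
import xgboost as xgb
from sklearn.model_selection import GridSearchCV

rng = np.random.RandomState(1994)
X = rng.randn(100, 5)
y = X[:, 0] * 2.0 + rng.randn(100) * 0.1

model = xgb.XGBRegressor(learning_rate=0.1)   # everything else left to XGBoost defaults
search = GridSearchCV(model, {'max_depth': [2, 4], 'n_estimators': [10, 20]}, cv=3)
search.fit(X, y)
print(search.best_params_)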