diff --git a/demo/rank/README.md b/demo/rank/README.md index 55dcb4ee5..f8bd5b1b7 100644 --- a/demo/rank/README.md +++ b/demo/rank/README.md @@ -1,6 +1,6 @@ Learning to rank ==== -XGBoost supports accomplishing ranking tasks. In ranking scenario, data are often grouped and we need the [group information file](../../doc/input_format.md#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is the LambdaRank, this function is not yet completed. Currently, we provide pairwise rank. +XGBoost supports accomplishing ranking tasks. In a ranking scenario, data are often grouped, and we need the [group information file](../../doc/tutorials/input_format.md#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is LambdaRank; this functionality is not yet complete. Currently, we provide pairwise rank. ### Parameters The configuration setting is similar to the regression and binary classification setting, except user need to specify the objectives: @@ -15,14 +15,27 @@ For more usage details please refer to the [binary classification demo](../binar Instructions ==== The dataset for ranking demo is from LETOR04 MQ2008 fold1. -You can use the following command to run the example: +Before running the examples, you need to get the data by running: -Get the data: ``` ./wgetdata.sh ``` +### Command Line Run the example: ``` ./runexp.sh ``` + +### Python +There are two ways of doing ranking in Python. 
+ +Run the example using `xgboost.train`: +``` +python rank.py +``` + +Run the example using `XGBRanker`: +``` +python rank_sklearn.py +``` diff --git a/demo/rank/rank.py b/demo/rank/rank.py new file mode 100644 index 000000000..2bc260574 --- /dev/null +++ b/demo/rank/rank.py @@ -0,0 +1,33 @@ +#!/usr/bin/python +import xgboost as xgb +from xgboost import DMatrix +from sklearn.datasets import load_svmlight_file + + +def load_group(path): + """Read a group file: one integer group size per line, blank lines skipped.""" + with open(path, "r") as f: + return [int(line) for line in f if line.strip()] + + +# This script demonstrates how to do ranking with xgboost.train +x_train, y_train = load_svmlight_file("mq2008.train") +x_valid, y_valid = load_svmlight_file("mq2008.vali") +x_test, y_test = load_svmlight_file("mq2008.test") + +group_train = load_group("mq2008.train.group") +group_valid = load_group("mq2008.vali.group") +group_test = load_group("mq2008.test.group") + +train_dmatrix = DMatrix(x_train, y_train) +valid_dmatrix = DMatrix(x_valid, y_valid) +test_dmatrix = DMatrix(x_test) + +train_dmatrix.set_group(group_train) +valid_dmatrix.set_group(group_valid) + +params = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0, + 'min_child_weight': 0.1, 'max_depth': 6} +xgb_model = xgb.train(params, train_dmatrix, num_boost_round=4, + evals=[(valid_dmatrix, 'validation')]) +pred = xgb_model.predict(test_dmatrix) diff --git a/demo/rank/rank_sklearn.py b/demo/rank/rank_sklearn.py new file mode 100644 index 000000000..1d8341d17 --- /dev/null +++ b/demo/rank/rank_sklearn.py @@ -0,0 +1,27 @@ +#!/usr/bin/python +import xgboost as xgb +from sklearn.datasets import load_svmlight_file + + +def load_group(path): + """Read a group file: one integer group size per line, blank lines skipped.""" + with open(path, "r") as f: + return [int(line) for line in f if line.strip()] + + +# This script demonstrates how to do ranking with XGBRanker +x_train, y_train = load_svmlight_file("mq2008.train") +x_valid, y_valid = load_svmlight_file("mq2008.vali") 
+x_test, y_test = load_svmlight_file("mq2008.test") + +group_train = load_group("mq2008.train.group") +group_valid = load_group("mq2008.vali.group") +group_test = load_group("mq2008.test.group") + +params = {'objective': 'rank:pairwise', 'learning_rate': 0.1, + 'gamma': 1.0, 'min_child_weight': 0.1, + 'max_depth': 6, 'n_estimators': 4} +model = xgb.sklearn.XGBRanker(**params) +model.fit(x_train, y_train, group_train, + eval_set=[(x_valid, y_valid)], eval_group=[group_valid]) +pred = model.predict(x_test) diff --git a/dmlc-core b/dmlc-core index f2afdc778..459ab734d 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit f2afdc7788ee8ed6fd06cc095b6838d4ce61bb5a +Subproject commit 459ab734d15acd68fd437abf845c7c1730b5a38f diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index f64cc983a..f78bf9439 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -101,7 +101,7 @@ class XGBModel(XGBModelBase): None, defaults to np.nan. **kwargs : dict, optional Keyword arguments for XGBoost Booster object. Full documentation of parameters can - be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.md. + be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst. Attempting to set a parameter via the constructor args and **kwargs dict simultaneously will result in a TypeError. Note: @@ -259,7 +259,7 @@ class XGBModel(XGBModelBase): instance weights on the i-th validation set. eval_metric : str, callable, optional If a str, should be a built-in evaluation metric to use. See - doc/parameter.md. If callable, a custom evaluation metric. 
The call + doc/parameter.rst. If callable, a custom evaluation metric. The call signature is func(y_predicted, y_true) where y_true will be a DMatrix object such that you may need to call the get_label method. It must return a str, value pair where the str is a name @@ -465,7 +465,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): instance weights on the i-th validation set. eval_metric : str, callable, optional If a str, should be a built-in evaluation metric to use. See - doc/parameter.md. If callable, a custom evaluation metric. The call + doc/parameter.rst. If callable, a custom evaluation metric. The call signature is func(y_predicted, y_true) where y_true will be a DMatrix object such that you may need to call the get_label method. It must return a str, value pair where the str is a name @@ -679,3 +679,248 @@ class XGBRegressor(XGBModel, XGBRegressorBase): # pylint: disable=missing-docstring __doc__ = """Implementation of the scikit-learn API for XGBoost regression. """ + '\n'.join(XGBModel.__doc__.split('\n')[2:]) + + +class XGBRanker(XGBModel): + # pylint: disable=missing-docstring,too-many-arguments,invalid-name + """Implementation of the Scikit-Learn API for XGBoost Ranking. + + Parameters + ---------- + max_depth : int + Maximum tree depth for base learners. + learning_rate : float + Boosting learning rate (xgb's "eta") + n_estimators : int + Number of boosted trees to fit. + silent : boolean + Whether to print messages while running boosting. + objective : string + Specify the learning task and the corresponding learning objective. + Only "rank:pairwise" is supported currently. + booster: string + Specify which booster to use: gbtree, gblinear or dart. + nthread : int + Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs) + n_jobs : int + Number of parallel threads used to run xgboost. (replaces nthread) + gamma : float + Minimum loss reduction required to make a further partition on a leaf node of the tree. 
+ min_child_weight : int + Minimum sum of instance weight(hessian) needed in a child. + max_delta_step : int + Maximum delta step we allow each tree's weight estimation to be. + subsample : float + Subsample ratio of the training instance. + colsample_bytree : float + Subsample ratio of columns when constructing each tree. + colsample_bylevel : float + Subsample ratio of columns for each split, in each level. + reg_alpha : float (xgb's alpha) + L1 regularization term on weights + reg_lambda : float (xgb's lambda) + L2 regularization term on weights + scale_pos_weight : float + Balancing of positive and negative weights. + base_score: + The initial prediction score of all instances, global bias. + seed : int + Random number seed. (Deprecated, please use random_state) + random_state : int + Random number seed. (replaces seed) + missing : float, optional + Value in the data which needs to be present as a missing value. If + None, defaults to np.nan. + **kwargs : dict, optional + Keyword arguments for XGBoost Booster object. Full documentation of parameters can + be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst. + Attempting to set a parameter via the constructor args and **kwargs dict simultaneously + will result in a TypeError. + Note: + **kwargs is unsupported by Sklearn. We do not guarantee that parameters passed via + this argument will interact properly with Sklearn. + + Note + ---- + A custom objective function is currently not supported by XGBRanker. + + Group information is required for ranking tasks. Before fitting the model, your data need to + be sorted by group. When fitting the model, you need to provide an additional array that + contains the size of each group. + + For example, if your original data look like: + + | qid | label | features | + | 1 | 0 | x_1 | + | 1 | 1 | x_2 | + | 1 | 0 | x_3 | + | 2 | 0 | x_4 | + | 2 | 1 | x_5 | + | 2 | 1 | x_6 | + | 2 | 1 | x_7 | + + then your group array should be [3, 4]. 
+ """ + + def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, + silent=True, objective="rank:pairwise", booster='gbtree', + n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0, + subsample=1, colsample_bytree=1, colsample_bylevel=1, + reg_alpha=0, reg_lambda=1, scale_pos_weight=1, + base_score=0.5, random_state=0, seed=None, missing=None, **kwargs): + + super(XGBRanker, self).__init__(max_depth, learning_rate, + n_estimators, silent, objective, booster, + n_jobs, nthread, gamma, min_child_weight, max_delta_step, + subsample, colsample_bytree, colsample_bylevel, + reg_alpha, reg_lambda, scale_pos_weight, + base_score, random_state, seed, missing, **kwargs) + if callable(self.objective): + raise ValueError("custom objective function not supported by XGBRanker") + elif self.objective != "rank:pairwise": + raise ValueError("XGBRanker only supports the 'rank:pairwise' objective") + + def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None, + eval_group=None, eval_metric=None, early_stopping_rounds=None, + verbose=False, xgb_model=None): + # pylint: disable = attribute-defined-outside-init,arguments-differ + """ + Fit the gradient boosting model + + Parameters + ---------- + X : array_like + Feature matrix + y : array_like + Labels + group : array_like + group size of training data + sample_weight : array_like + instance weights + eval_set : list, optional + A list of (X, y) tuple pairs to use as a validation set for + early-stopping + sample_weight_eval_set : list, optional + A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of + instance weights on the i-th validation set. + eval_group : list of arrays, optional + A list of group size arrays, one for each + (X, y) pair in eval_set + eval_metric : str, callable, optional + If a str, should be a built-in evaluation metric to use. See + doc/parameter.rst. If callable, a custom evaluation metric. 
The call + signature is func(y_predicted, y_true) where y_true will be a + DMatrix object such that you may need to call the get_label + method. It must return a str, value pair where the str is a name + for the evaluation and value is the value of the evaluation + function. This objective is always minimized. + early_stopping_rounds : int + Activates early stopping. Validation error needs to decrease at + least every <early_stopping_rounds> round(s) to continue training. + Requires at least one item in eval_set. If there's more than one, + will use the last. Returns the model from the last iteration + (not the best one). If early stopping occurs, the model will + have three additional fields: bst.best_score, bst.best_iteration + and bst.best_ntree_limit. + (Use bst.best_ntree_limit to get the correct value if num_parallel_tree + and/or num_class appears in the parameters) + verbose : bool + If `verbose` and an evaluation set is used, writes the evaluation + metric measured on the validation set to stderr. + xgb_model : str + file name of stored xgb model or 'Booster' instance Xgb model to be + loaded before training (allows training continuation). 
+ """ + # check if group information is provided + if group is None: + raise ValueError("group is required for ranking task") + + if eval_set is not None: + if eval_group is None: + raise ValueError("eval_group is required if eval_set is not None") + elif len(eval_group) != len(eval_set): + raise ValueError("length of eval_group should match that of eval_set") + elif any(group is None for group in eval_group): + raise ValueError("group is required for all eval datasets for ranking task") + + def _dmat_init(group, **params): + ret = DMatrix(**params) + ret.set_group(group) + return ret + + if sample_weight is not None: + train_dmatrix = _dmat_init(group, data=X, label=y, weight=sample_weight, + missing=self.missing, nthread=self.n_jobs) + else: + train_dmatrix = _dmat_init(group, data=X, label=y, + missing=self.missing, nthread=self.n_jobs) + + evals_result = {} + + if eval_set is not None: + if sample_weight_eval_set is None: + sample_weight_eval_set = [None] * len(eval_set) + evals = [_dmat_init(eval_group[i], data=eval_set[i][0], label=eval_set[i][1], + missing=self.missing, weight=sample_weight_eval_set[i], + nthread=self.n_jobs) for i in range(len(eval_set))] + nevals = len(evals) + eval_names = ["eval_{}".format(i) for i in range(nevals)] + evals = list(zip(evals, eval_names)) + else: + evals = () + + params = self.get_xgb_params() + + feval = eval_metric if callable(eval_metric) else None + if eval_metric is not None: + if callable(eval_metric): + eval_metric = None + else: + params.update({'eval_metric': eval_metric}) + + self._Booster = train(params, train_dmatrix, + self.n_estimators, + early_stopping_rounds=early_stopping_rounds, evals=evals, + evals_result=evals_result, feval=feval, + verbose_eval=verbose, xgb_model=xgb_model) + + self.objective = params["objective"] + + if evals_result: + for val in evals_result.items(): + evals_result_key = list(val[1].keys())[0] + evals_result[val[0]][evals_result_key] = val[1][evals_result_key] + self.evals_result = 
evals_result + + if early_stopping_rounds is not None: + self.best_score = self._Booster.best_score + self.best_iteration = self._Booster.best_iteration + self.best_ntree_limit = self._Booster.best_ntree_limit + + return self + + def predict(self, data, output_margin=False, ntree_limit=None): + """ + Predict ranking scores with the fitted booster + + Parameters + ---------- + data : array_like + Feature matrix, converted to a DMatrix with self.missing + output_margin : bool + Passed through to the booster's predict method + ntree_limit : int, optional + Passed through to the booster's predict method. If None, + best_ntree_limit is used when fit has set it, otherwise 0. + + Returns + ------- + prediction : the return value of the booster's predict method + """ + test_dmatrix = DMatrix(data, missing=self.missing) + if ntree_limit is None: + ntree_limit = getattr(self, "best_ntree_limit", 0) + + return self.get_booster().predict(test_dmatrix, + output_margin=output_margin, + ntree_limit=ntree_limit) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index b184b3952..7508d1b80 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -77,6 +77,40 @@ def test_multiclass_classification(): check_pred(preds4, labels) +def test_ranking(): + tm._skip_if_no_sklearn() + # generate random data + x_train = np.random.rand(1000, 10) + y_train = np.random.randint(5, size=1000) + train_group = np.repeat(50, 20) + x_valid = np.random.rand(200, 10) + y_valid = np.random.randint(5, size=200) + valid_group = np.repeat(50, 4) + x_test = np.random.rand(100, 10) + + params = {'objective': 'rank:pairwise', 'learning_rate': 0.1, + 'gamma': 1.0, 'min_child_weight': 0.1, + 'max_depth': 6, 'n_estimators': 4} + model = xgb.sklearn.XGBRanker(**params) + model.fit(x_train, y_train, train_group, + eval_set=[(x_valid, y_valid)], eval_group=[valid_group]) + pred = model.predict(x_test) + + train_data = xgb.DMatrix(x_train, y_train) + valid_data = xgb.DMatrix(x_valid, y_valid) + test_data = xgb.DMatrix(x_test) + train_data.set_group(train_group) + valid_data.set_group(valid_group) + + params_orig = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0, + 'min_child_weight': 0.1, 'max_depth': 6} + xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4, + evals=[(valid_data, 'validation')]) + pred_orig = xgb_model_orig.predict(test_data) + + np.testing.assert_almost_equal(pred, pred_orig) + 
+ def test_feature_importances(): tm._skip_if_no_sklearn() from sklearn.datasets import load_digits