sklearn api for ranking (#3560)

* added xgbranker

* fixed predict method and ranking test

* reformatted code in accordance with pep8

* fixed lint error

* fixed docstring and added checks on objective

* added ranking demo for python

* fixed suffix in rank.py
This commit is contained in:
Shiki-H
2018-08-21 11:26:48 -04:00
committed by Philip Hyunsu Cho
parent b13c3a8bcc
commit 24a268a2e3
6 changed files with 359 additions and 7 deletions

View File

@@ -77,6 +77,40 @@ def test_multiclass_classification():
check_pred(preds4, labels)
def test_ranking():
tm._skip_if_no_sklearn()
# generate random data
x_train = np.random.rand(1000, 10)
y_train = np.random.randint(5, size=1000)
train_group = np.repeat(50, 20)
x_valid = np.random.rand(200, 10)
y_valid = np.random.randint(5, size=200)
valid_group = np.repeat(50, 4)
x_test = np.random.rand(100, 10)
params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
'gamma': 1.0, 'min_child_weight': 0.1,
'max_depth': 6, 'n_estimators': 4}
model = xgb.sklearn.XGBRanker(**params)
model.fit(x_train, y_train, train_group,
eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
pred = model.predict(x_test)
train_data = xgb.DMatrix(x_train, y_train)
valid_data = xgb.DMatrix(x_valid, y_valid)
test_data = xgb.DMatrix(x_test)
train_data.set_group(train_group)
valid_data.set_group(valid_group)
params_orig = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
'min_child_weight': 0.1, 'max_depth': 6}
xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4,
evals=[(valid_data, 'validation')])
pred_orig = xgb_model_orig.predict(test_data)
np.testing.assert_almost_equal(pred, pred_orig)
def test_feature_importances():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits