sklearn api for ranking (#3560)
* added xgbranker * fixed predict method and ranking test * reformatted code in accordance with pep8 * fixed lint error * fixed docstring and added checks on objective * added ranking demo for python * fixed suffix in rank.py
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
b13c3a8bcc
commit
24a268a2e3
@@ -77,6 +77,40 @@ def test_multiclass_classification():
|
||||
check_pred(preds4, labels)
|
||||
|
||||
|
||||
def test_ranking():
|
||||
tm._skip_if_no_sklearn()
|
||||
# generate random data
|
||||
x_train = np.random.rand(1000, 10)
|
||||
y_train = np.random.randint(5, size=1000)
|
||||
train_group = np.repeat(50, 20)
|
||||
x_valid = np.random.rand(200, 10)
|
||||
y_valid = np.random.randint(5, size=200)
|
||||
valid_group = np.repeat(50, 4)
|
||||
x_test = np.random.rand(100, 10)
|
||||
|
||||
params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
|
||||
'gamma': 1.0, 'min_child_weight': 0.1,
|
||||
'max_depth': 6, 'n_estimators': 4}
|
||||
model = xgb.sklearn.XGBRanker(**params)
|
||||
model.fit(x_train, y_train, train_group,
|
||||
eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
|
||||
pred = model.predict(x_test)
|
||||
|
||||
train_data = xgb.DMatrix(x_train, y_train)
|
||||
valid_data = xgb.DMatrix(x_valid, y_valid)
|
||||
test_data = xgb.DMatrix(x_test)
|
||||
train_data.set_group(train_group)
|
||||
valid_data.set_group(valid_group)
|
||||
|
||||
params_orig = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
|
||||
'min_child_weight': 0.1, 'max_depth': 6}
|
||||
xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4,
|
||||
evals=[(valid_data, 'validation')])
|
||||
pred_orig = xgb_model_orig.predict(test_data)
|
||||
|
||||
np.testing.assert_almost_equal(pred, pred_orig)
|
||||
|
||||
|
||||
def test_feature_importances():
|
||||
tm._skip_if_no_sklearn()
|
||||
from sklearn.datasets import load_digits
|
||||
|
||||
Reference in New Issue
Block a user