Add support for cross-validation using query ID (#4474)
* adding support for matrix slicing with query ID for cross-validation * hail mary test of unrar installation for windows tests * trying to modify tests to run in Github CI * Remove dependency on wget and unrar * Save error log from R test * Relax assertion in test_training * Use int instead of bool in C function interface * Revise R interface * Add XGDMatrixSliceDMatrixEx and keep old XGDMatrixSliceDMatrix for API compatibility
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
5a567ec249
commit
278562db13
@@ -1,6 +1,16 @@
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
import xgboost
|
||||
import sys
|
||||
import os
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
import unittest
|
||||
import itertools
|
||||
import glob
|
||||
import shutil
|
||||
import urllib.request
|
||||
import zipfile
|
||||
|
||||
|
||||
def test_ranking_with_unweighted_data():
|
||||
Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])
|
||||
@@ -63,3 +73,110 @@ def test_ranking_with_weighted_data():
|
||||
# the ranking predictor will first try to correctly sort the last query group
|
||||
# before correctly sorting other groups.
|
||||
assert all(p <= q for p, q in zip(is_sorted, is_sorted[1:]))
|
||||
|
||||
|
||||
class TestRanking(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Download and setup the test fixtures
|
||||
"""
|
||||
# download the test data
|
||||
cls.dpath = 'demo/rank/'
|
||||
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
|
||||
target = cls.dpath + '/MQ2008.zip'
|
||||
urllib.request.urlretrieve(url=src, filename=target)
|
||||
|
||||
with zipfile.ZipFile(target, 'r') as f:
|
||||
f.extractall(path=cls.dpath)
|
||||
|
||||
x_train, y_train, qid_train, x_test, y_test, qid_test, x_valid, y_valid, qid_valid = load_svmlight_files(
|
||||
(cls.dpath + "MQ2008/Fold1/train.txt",
|
||||
cls.dpath + "MQ2008/Fold1/test.txt",
|
||||
cls.dpath + "MQ2008/Fold1/vali.txt"),
|
||||
query_id=True, zero_based=False)
|
||||
# instantiate the matrices
|
||||
cls.dtrain = xgboost.DMatrix(x_train, y_train)
|
||||
cls.dvalid = xgboost.DMatrix(x_valid, y_valid)
|
||||
cls.dtest = xgboost.DMatrix(x_test, y_test)
|
||||
# set the group counts from the query IDs
|
||||
cls.dtrain.set_group([len(list(items))
|
||||
for _key, items in itertools.groupby(qid_train)])
|
||||
cls.dtest.set_group([len(list(items))
|
||||
for _key, items in itertools.groupby(qid_test)])
|
||||
cls.dvalid.set_group([len(list(items))
|
||||
for _key, items in itertools.groupby(qid_valid)])
|
||||
# save the query IDs for testing
|
||||
cls.qid_train = qid_train
|
||||
cls.qid_test = qid_test
|
||||
cls.qid_valid = qid_valid
|
||||
|
||||
# model training parameters
|
||||
cls.params = {'objective': 'rank:pairwise',
|
||||
'booster': 'gbtree',
|
||||
'silent': 0,
|
||||
'eval_metric': ['ndcg']
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
"""
|
||||
Cleanup test artifacts from download and unpacking
|
||||
:return:
|
||||
"""
|
||||
os.remove(cls.dpath + "MQ2008.zip")
|
||||
shutil.rmtree(cls.dpath + "MQ2008")
|
||||
|
||||
def test_training(self):
|
||||
"""
|
||||
Train an XGBoost ranking model
|
||||
"""
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(self.dtest, 'eval'), (self.dtrain, 'train')]
|
||||
bst = xgboost.train(self.params, self.dtrain, num_boost_round=2500,
|
||||
early_stopping_rounds=10, evals=watchlist)
|
||||
assert bst.best_score > 0.98
|
||||
|
||||
def test_cv(self):
|
||||
"""
|
||||
Test cross-validation with a group specified
|
||||
"""
|
||||
cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
|
||||
early_stopping_rounds=10, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
self.assertSetEqual(set(cv.keys()), {'test-ndcg-mean', 'train-ndcg-mean', 'test-ndcg-std', 'train-ndcg-std'},
|
||||
"CV results dict key mismatch")
|
||||
|
||||
def test_cv_no_shuffle(self):
|
||||
"""
|
||||
Test cross-validation with a group specified
|
||||
"""
|
||||
cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
|
||||
early_stopping_rounds=10, shuffle=False, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == 4
|
||||
|
||||
def test_get_group(self):
|
||||
"""
|
||||
Retrieve the group number from the dmatrix
|
||||
"""
|
||||
# control that should work
|
||||
self.dtrain.get_uint_info('root_index')
|
||||
# test the new getter
|
||||
self.dtrain.get_uint_info('group_ptr')
|
||||
|
||||
for d, qid in [(self.dtrain, self.qid_train),
|
||||
(self.dvalid, self.qid_valid),
|
||||
(self.dtest, self.qid_test)]:
|
||||
# size of each group
|
||||
group_sizes = np.array([len(list(items))
|
||||
for _key, items in itertools.groupby(qid)])
|
||||
# indexes of group boundaries
|
||||
group_limits = d.get_uint_info('group_ptr')
|
||||
assert len(group_limits) == len(group_sizes)+1
|
||||
assert np.array_equal(np.diff(group_limits), group_sizes)
|
||||
assert np.array_equal(
|
||||
group_sizes, np.diff(d.get_uint_info('group_ptr')))
|
||||
assert np.array_equal(group_sizes, np.diff(d.get_uint_info('group_ptr')))
|
||||
assert np.array_equal(group_limits, d.get_uint_info('group_ptr'))
|
||||
|
||||
Reference in New Issue
Block a user