Add support for cross-validation using query ID (#4474)
* Adding support for matrix slicing with query ID for cross-validation
* Hail mary test of unrar installation for Windows tests
* Trying to modify tests to run in GitHub CI
* Remove dependency on wget and unrar
* Save error log from R test
* Relax assertion in test_training
* Use int instead of bool in C function interface
* Revise R interface
* Add XGDMatrixSliceDMatrixEx and keep old XGDMatrixSliceDMatrix for API compatibility
This commit is contained in:
parent
5a567ec249
commit
278562db13
@ -89,3 +89,5 @@ List of Contributors
|
||||
* [Sam Wilkinson](https://samwilkinson.io)
|
||||
* [Matthew Jones](https://github.com/mt-jones)
|
||||
* [Jiaxiang Li](https://github.com/JiaxiangBU)
|
||||
* [Bryan Woods](https://github.com/bryan-woods)
|
||||
- Bryan added support for cross-validation for the ranking objective
|
||||
|
||||
2
Jenkinsfile
vendored
2
Jenkinsfile
vendored
@ -340,6 +340,8 @@ def TestR(args) {
|
||||
sh """
|
||||
${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_test_rpkg.sh
|
||||
"""
|
||||
// Save error log, if any
|
||||
archiveArtifacts artifacts: "xgboost.Rcheck/00install.out", allowEmptyArchive: true
|
||||
deleteDir()
|
||||
}
|
||||
}
|
||||
|
||||
@ -136,9 +136,10 @@ SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
|
||||
idxvec[i] = INTEGER(idxset)[i] - 1;
|
||||
}
|
||||
DMatrixHandle res;
|
||||
CHECK_CALL(XGDMatrixSliceDMatrix(R_ExternalPtrAddr(handle),
|
||||
BeginPtr(idxvec), len,
|
||||
&res));
|
||||
CHECK_CALL(XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
|
||||
BeginPtr(idxvec), len,
|
||||
&res,
|
||||
0));
|
||||
ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
|
||||
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
||||
R_API_END();
|
||||
|
||||
@ -221,6 +221,20 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
const int *idxset,
|
||||
bst_ulong len,
|
||||
DMatrixHandle *out);
|
||||
/*!
|
||||
* \brief create a new dmatrix from sliced content of existing matrix
|
||||
* \param handle instance of data matrix to be sliced
|
||||
* \param idxset index set
|
||||
* \param len length of index set
|
||||
* \param out a sliced new matrix
|
||||
* \param allow_groups nonzero to allow slicing a matrix that has group structure; when 0, slicing such a matrix fails
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
|
||||
const int *idxset,
|
||||
bst_ulong len,
|
||||
DMatrixHandle *out,
|
||||
int allow_groups);
|
||||
/*!
|
||||
* \brief free space in data matrix
|
||||
* \return 0 when success, -1 when failure happens
|
||||
|
||||
@ -279,7 +279,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMat
|
||||
jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset);
|
||||
|
||||
jint ret = (jint) XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result);
|
||||
// default to not allowing slicing with group ID specified -- feel free to add if necessary
|
||||
jint ret = (jint) XGDMatrixSliceDMatrixEx(handle, (int const *)indexset, len, &result, 0);
|
||||
setHandle(jenv, jout, result);
|
||||
//release
|
||||
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
|
||||
|
||||
@ -795,13 +795,15 @@ class DMatrix(object):
|
||||
ctypes.byref(ret)))
|
||||
return ret.value
|
||||
|
||||
def slice(self, rindex):
|
||||
def slice(self, rindex, allow_groups=False):
|
||||
"""Slice the DMatrix and return a new DMatrix that only contains `rindex`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rindex : list
|
||||
List of indices to be selected.
|
||||
allow_groups : boolean
|
||||
Allow slicing of a matrix with a groups attribute
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -811,10 +813,11 @@ class DMatrix(object):
|
||||
res = DMatrix(None, feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
res.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
|
||||
c_array(ctypes.c_int, rindex),
|
||||
c_bst_ulong(len(rindex)),
|
||||
ctypes.byref(res.handle)))
|
||||
_check_call(_LIB.XGDMatrixSliceDMatrixEx(self.handle,
|
||||
c_array(ctypes.c_int, rindex),
|
||||
c_bst_ulong(len(rindex)),
|
||||
ctypes.byref(res.handle),
|
||||
ctypes.c_int(1 if allow_groups else 0)))
|
||||
return res
|
||||
|
||||
@property
|
||||
|
||||
@ -234,6 +234,56 @@ class CVPack(object):
|
||||
return self.bst.eval_set(self.watchlist, iteration, feval)
|
||||
|
||||
|
||||
def groups_to_rows(groups, boundaries):
    """
    Given group row boundaries, convert group indexes to row indexes

    :param groups: list of group indexes to expand into rows
    :param boundaries: row index limits of each group; group ``g`` covers
        rows ``boundaries[g]`` (inclusive) to ``boundaries[g+1]`` (exclusive)
    :return: array of row indexes belonging to the given groups, in the
        order the groups were listed (empty when no groups are given)
    """
    row_ranges = [np.arange(boundaries[g], boundaries[g + 1]) for g in groups]
    if not row_ranges:
        # np.concatenate raises ValueError on an empty sequence; an empty fold
        # (more folds than groups) should simply select no rows.
        return np.array([], dtype=int)
    return np.concatenate(row_ranges)
|
||||
|
||||
|
||||
def mkgroupfold(dall, nfold, param, evals=(), fpreproc=None, shuffle=True):
    """
    Make n folds for cross-validation while keeping every group intact

    :param dall: data matrix to split; its 'group_ptr' uint info supplies the
        row boundaries of each group
    :param nfold: number of folds to create
    :param param: booster parameters; copied before being handed to fpreproc
    :param evals: evaluation metrics, each appended as an 'eval_metric' entry
    :param fpreproc: optional callable taking (dtrain, dtest, param) and
        returning the transformed (dtrain, dtest, param)
    :param shuffle: when True, the group order is randomly permuted before
        being split into folds
    :return: cross-validation folds, one CVPack per fold
    """
    # group boundaries give both the size of each group and its row range
    group_ptr = dall.get_uint_info('group_ptr')
    sizes = np.diff(group_ptr)
    n_groups = len(sizes)

    if shuffle is True:
        order = np.random.permutation(n_groups)
    else:
        order = np.arange(n_groups)

    # test groups per fold; the training groups are all remaining folds
    test_groups = np.array_split(order, nfold)
    train_groups = []
    for k in range(nfold):
        other_folds = [test_groups[i] for i in range(nfold) if k != i]
        train_groups.append(np.concatenate(other_folds))

    # expand group indexes into row indexes
    train_rows = [groups_to_rows(grp, group_ptr) for grp in train_groups]
    test_rows = [groups_to_rows(grp, group_ptr) for grp in test_groups]

    folds = []
    for k in range(nfold):
        # slice out the rows and re-attach the matching group sizes
        dtrain = dall.slice(train_rows[k], allow_groups=True)
        dtrain.set_group(sizes[train_groups[k]])
        dtest = dall.slice(test_rows[k], allow_groups=True)
        dtest.set_group(sizes[test_groups[k]])

        # optional user preprocessing works on a copy of the parameters
        if fpreproc is None:
            tparam = param
        else:
            dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())

        metric_entries = [('eval_metric', itm) for itm in evals]
        folds.append(CVPack(dtrain, dtest, list(tparam.items()) + metric_entries))
    return folds
|
||||
|
||||
|
||||
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
||||
folds=None, shuffle=True):
|
||||
"""
|
||||
@ -243,16 +293,17 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
||||
np.random.seed(seed)
|
||||
|
||||
if stratified is False and folds is None:
|
||||
# Do standard k-fold cross validation
|
||||
# Do standard k-fold cross validation. Automatically determine the folds.
|
||||
if len(dall.get_uint_info('group_ptr')) > 1:
|
||||
return mkgroupfold(dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle)
|
||||
|
||||
if shuffle is True:
|
||||
idx = np.random.permutation(dall.num_row())
|
||||
else:
|
||||
idx = np.arange(dall.num_row())
|
||||
out_idset = np.array_split(idx, nfold)
|
||||
in_idset = [
|
||||
np.concatenate([out_idset[i] for i in range(nfold) if k != i])
|
||||
for k in range(nfold)
|
||||
]
|
||||
in_idset = [np.concatenate([out_idset[i] for i in range(nfold) if k != i])
|
||||
for k in range(nfold)]
|
||||
elif folds is not None:
|
||||
# Use user specified custom split using indices
|
||||
try:
|
||||
@ -274,6 +325,7 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
||||
|
||||
ret = []
|
||||
for k in range(nfold):
|
||||
# perform the slicing using the indexes determined by the above methods
|
||||
dtrain = dall.slice(in_idset[k])
|
||||
dtest = dall.slice(out_idset[k])
|
||||
# run preprocessing on the data set if needed
|
||||
|
||||
@ -674,6 +674,14 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
const int* idxset,
|
||||
xgboost::bst_ulong len,
|
||||
DMatrixHandle* out) {
|
||||
return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0);
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
|
||||
const int* idxset,
|
||||
xgboost::bst_ulong len,
|
||||
DMatrixHandle* out,
|
||||
int allow_groups) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
@ -682,8 +690,10 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
|
||||
data::SimpleCSRSource& ret = *source;
|
||||
|
||||
CHECK_EQ(src.info.group_ptr_.size(), 0U)
|
||||
if (!allow_groups) {
|
||||
CHECK_EQ(src.info.group_ptr_.size(), 0U)
|
||||
<< "slice does not support group structure";
|
||||
}
|
||||
|
||||
ret.Clear();
|
||||
ret.info.num_row_ = len;
|
||||
@ -814,11 +824,14 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
||||
const std::vector<unsigned>* vec = nullptr;
|
||||
if (!std::strcmp(field, "root_index")) {
|
||||
vec = &info.root_index_;
|
||||
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
|
||||
*out_dptr = dmlc::BeginPtr(*vec);
|
||||
} else if (!std::strcmp(field, "group_ptr")) {
|
||||
vec = &info.group_ptr_;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown uint field name " << field;
|
||||
LOG(FATAL) << "Unknown comp uint field name " << field
|
||||
<< " with comparison " << std::strcmp(field, "group_ptr");
|
||||
}
|
||||
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
|
||||
*out_dptr = dmlc::BeginPtr(*vec);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,16 @@
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
import xgboost
|
||||
import sys
|
||||
import os
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
import unittest
|
||||
import itertools
|
||||
import glob
|
||||
import shutil
|
||||
import urllib.request
|
||||
import zipfile
|
||||
|
||||
|
||||
def test_ranking_with_unweighted_data():
|
||||
Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])
|
||||
@ -63,3 +73,110 @@ def test_ranking_with_weighted_data():
|
||||
# the ranking predictor will first try to correctly sort the last query group
|
||||
# before correctly sorting other groups.
|
||||
assert all(p <= q for p, q in zip(is_sorted, is_sorted[1:]))
|
||||
|
||||
|
||||
class TestRanking(unittest.TestCase):
    """Ranking tests on the MQ2008 Fold1 data: training, CV with groups, group_ptr getter."""

    @staticmethod
    def _group_sizes(qid):
        """Return the length of each run of consecutive equal query IDs."""
        return [len(list(items)) for _key, items in itertools.groupby(qid)]

    @classmethod
    def setUpClass(cls):
        """
        Download and setup the test fixtures
        """
        # download the test data
        cls.dpath = 'demo/rank/'
        src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
        # build the path the same way tearDownClass does, so both refer to one file
        target = os.path.join(cls.dpath, 'MQ2008.zip')
        urllib.request.urlretrieve(url=src, filename=target)

        with zipfile.ZipFile(target, 'r') as f:
            f.extractall(path=cls.dpath)

        fold_dir = os.path.join(cls.dpath, 'MQ2008', 'Fold1')
        (x_train, y_train, qid_train,
         x_test, y_test, qid_test,
         x_valid, y_valid, qid_valid) = load_svmlight_files(
             (os.path.join(fold_dir, 'train.txt'),
              os.path.join(fold_dir, 'test.txt'),
              os.path.join(fold_dir, 'vali.txt')),
             query_id=True, zero_based=False)
        # instantiate the matrices
        cls.dtrain = xgboost.DMatrix(x_train, y_train)
        cls.dvalid = xgboost.DMatrix(x_valid, y_valid)
        cls.dtest = xgboost.DMatrix(x_test, y_test)
        # set the group counts from the query IDs
        cls.dtrain.set_group(cls._group_sizes(qid_train))
        cls.dtest.set_group(cls._group_sizes(qid_test))
        cls.dvalid.set_group(cls._group_sizes(qid_valid))
        # save the query IDs for testing
        cls.qid_train = qid_train
        cls.qid_test = qid_test
        cls.qid_valid = qid_valid

        # model training parameters
        cls.params = {'objective': 'rank:pairwise',
                      'booster': 'gbtree',
                      'silent': 0,
                      'eval_metric': ['ndcg']
                      }

    @classmethod
    def tearDownClass(cls):
        """
        Cleanup test artifacts from download and unpacking
        """
        os.remove(os.path.join(cls.dpath, 'MQ2008.zip'))
        shutil.rmtree(os.path.join(cls.dpath, 'MQ2008'))

    def test_training(self):
        """
        Train an XGBoost ranking model
        """
        # specify validations set to watch performance
        watchlist = [(self.dtest, 'eval'), (self.dtrain, 'train')]
        bst = xgboost.train(self.params, self.dtrain, num_boost_round=2500,
                            early_stopping_rounds=10, evals=watchlist)
        assert bst.best_score > 0.98

    def test_cv(self):
        """
        Test cross-validation with a group specified
        """
        cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
                        early_stopping_rounds=10, nfold=10, as_pandas=False)
        assert isinstance(cv, dict)
        self.assertSetEqual(set(cv.keys()),
                            {'test-ndcg-mean', 'train-ndcg-mean', 'test-ndcg-std', 'train-ndcg-std'},
                            "CV results dict key mismatch")

    def test_cv_no_shuffle(self):
        """
        Test cross-validation with a group specified and shuffling disabled
        """
        cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500,
                        early_stopping_rounds=10, shuffle=False, nfold=10, as_pandas=False)
        assert isinstance(cv, dict)
        assert len(cv) == 4

    def test_get_group(self):
        """
        Retrieve the group boundaries from the dmatrix
        """
        # control that should work
        self.dtrain.get_uint_info('root_index')
        # test the new getter
        self.dtrain.get_uint_info('group_ptr')

        for d, qid in [(self.dtrain, self.qid_train),
                       (self.dvalid, self.qid_valid),
                       (self.dtest, self.qid_test)]:
            # size of each group
            group_sizes = np.array(self._group_sizes(qid))
            # indexes of group boundaries
            group_limits = d.get_uint_info('group_ptr')
            assert len(group_limits) == len(group_sizes) + 1
            assert np.array_equal(np.diff(group_limits), group_sizes)
            # a second call must return the same boundaries
            assert np.array_equal(group_sizes, np.diff(d.get_uint_info('group_ptr')))
            assert np.array_equal(group_limits, d.get_uint_info('group_ptr'))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user