diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index f14767092..b8cb9607c 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -89,3 +89,5 @@ List of Contributors * [Sam Wilkinson](https://samwilkinson.io) * [Matthew Jones](https://github.com/mt-jones) * [Jiaxiang Li](https://github.com/JiaxiangBU) +* [Bryan Woods](https://github.com/bryan-woods) + - Bryan added support for cross-validation for the ranking objective diff --git a/Jenkinsfile b/Jenkinsfile index 8608096ab..c347e8eb9 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -340,6 +340,8 @@ def TestR(args) { sh """ ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_test_rpkg.sh """ + // Save error log, if any + archiveArtifacts artifacts: "xgboost.Rcheck/00install.out", allowEmptyArchive: true deleteDir() } } diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 5281ef6bd..2ca9b8e57 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -136,9 +136,10 @@ SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { idxvec[i] = INTEGER(idxset)[i] - 1; } DMatrixHandle res; - CHECK_CALL(XGDMatrixSliceDMatrix(R_ExternalPtrAddr(handle), - BeginPtr(idxvec), len, - &res)); + CHECK_CALL(XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle), + BeginPtr(idxvec), len, + &res, + 0)); ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue)); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_API_END(); diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 3328aba88..b7e1ec5ec 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -221,6 +221,20 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, const int *idxset, bst_ulong len, DMatrixHandle *out); +/*! 
+ * \brief create a new dmatrix from sliced content of existing matrix + * \param handle instance of data matrix to be sliced + * \param idxset index set + * \param len length of index set + * \param out a sliced new matrix + * \param allow_groups nonzero to allow slicing of a matrix that has groups + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, + const int *idxset, + bst_ulong len, + DMatrixHandle *out, + int allow_groups); /*! * \brief free space in data matrix * \return 0 when success, -1 when failure happens diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 285793879..e29beabfe 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -279,7 +279,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMat jint* indexset = jenv->GetIntArrayElements(jindexset, 0); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset); - jint ret = (jint) XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result); + // default to not allowing slicing with group ID specified -- feel free to add if necessary + jint ret = (jint) XGDMatrixSliceDMatrixEx(handle, (int const *)indexset, len, &result, 0); setHandle(jenv, jout, result); //release jenv->ReleaseIntArrayElements(jindexset, indexset, 0); diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 75764806f..dbf34e051 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -795,13 +795,15 @@ class DMatrix(object): ctypes.byref(ret))) return ret.value - def slice(self, rindex): + def slice(self, rindex, allow_groups=False): """Slice the DMatrix and return a new DMatrix that only contains `rindex`. Parameters ---------- rindex : list List of indices to be selected. 
+ allow_groups : boolean + Allow slicing of a matrix with a groups attribute Returns ------- @@ -811,10 +813,11 @@ class DMatrix(object): res = DMatrix(None, feature_names=self.feature_names, feature_types=self.feature_types) res.handle = ctypes.c_void_p() - _check_call(_LIB.XGDMatrixSliceDMatrix(self.handle, - c_array(ctypes.c_int, rindex), - c_bst_ulong(len(rindex)), - ctypes.byref(res.handle))) + _check_call(_LIB.XGDMatrixSliceDMatrixEx(self.handle, + c_array(ctypes.c_int, rindex), + c_bst_ulong(len(rindex)), + ctypes.byref(res.handle), + ctypes.c_int(1 if allow_groups else 0))) return res @property diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index 71c96736a..0bba8341e 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -234,6 +234,56 @@ class CVPack(object): return self.bst.eval_set(self.watchlist, iteration, feval) +def groups_to_rows(groups, boundaries): + """ + Given group row boundaries, convert group indexes to row indexes + :param groups: list of groups for testing + :param boundaries: rows index limits of each group + :return: row indexes of the given groups + """ + return np.concatenate([np.arange(boundaries[g], boundaries[g+1]) for g in groups]) + + +def mkgroupfold(dall, nfold, param, evals=(), fpreproc=None, shuffle=True): + """ + Make n folds for cross-validation maintaining groups + :return: cross-validation folds + """ + # we have groups for pairwise ranking... 
get a list of the group indexes + group_boundaries = dall.get_uint_info('group_ptr') + group_sizes = np.diff(group_boundaries) + + if shuffle is True: + idx = np.random.permutation(len(group_sizes)) + else: + idx = np.arange(len(group_sizes)) + # list by fold of test group indexes + out_group_idset = np.array_split(idx, nfold) + # list by fold of train group indexes + in_group_idset = [np.concatenate([out_group_idset[i] for i in range(nfold) if k != i]) + for k in range(nfold)] + # from the group indexes, convert them to row indexes + in_idset = [groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset] + out_idset = [groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset] + + # build the folds by taking the appropriate slices + ret = [] + for k in range(nfold): + # perform the slicing using the indexes determined by the above methods + dtrain = dall.slice(in_idset[k], allow_groups=True) + dtrain.set_group(group_sizes[in_group_idset[k]]) + dtest = dall.slice(out_idset[k], allow_groups=True) + dtest.set_group(group_sizes[out_group_idset[k]]) + # run preprocessing on the data set if needed + if fpreproc is not None: + dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy()) + else: + tparam = param + plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals] + ret.append(CVPack(dtrain, dtest, plst)) + return ret + + def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False, folds=None, shuffle=True): """ @@ -243,16 +293,17 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False, np.random.seed(seed) if stratified is False and folds is None: - # Do standard k-fold cross validation + # Do standard k-fold cross validation. Automatically determine the folds. 
+ if len(dall.get_uint_info('group_ptr')) > 1: + return mkgroupfold(dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle) + if shuffle is True: idx = np.random.permutation(dall.num_row()) else: idx = np.arange(dall.num_row()) out_idset = np.array_split(idx, nfold) - in_idset = [ - np.concatenate([out_idset[i] for i in range(nfold) if k != i]) - for k in range(nfold) - ] + in_idset = [np.concatenate([out_idset[i] for i in range(nfold) if k != i]) + for k in range(nfold)] elif folds is not None: # Use user specified custom split using indices try: @@ -274,6 +325,7 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False, ret = [] for k in range(nfold): + # perform the slicing using the indexes determined by the above methods dtrain = dall.slice(in_idset[k]) dtest = dall.slice(out_idset[k]) # run preprocessing on the data set if needed diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ac9c35c4c..3d85b94c3 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -674,6 +674,14 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, const int* idxset, xgboost::bst_ulong len, DMatrixHandle* out) { + return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0); +} + +XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, + const int* idxset, + xgboost::bst_ulong len, + DMatrixHandle* out, + int allow_groups) { std::unique_ptr source(new data::SimpleCSRSource()); API_BEGIN(); @@ -682,8 +690,10 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, src.CopyFrom(static_cast*>(handle)->get()); data::SimpleCSRSource& ret = *source; - CHECK_EQ(src.info.group_ptr_.size(), 0U) + if (!allow_groups) { + CHECK_EQ(src.info.group_ptr_.size(), 0U) << "slice does not support group structure"; + } ret.Clear(); ret.info.num_row_ = len; @@ -814,11 +824,14 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle, const std::vector* vec = nullptr; if (!std::strcmp(field, "root_index")) { vec = &info.root_index_; - *out_len = 
static_cast(vec->size()); - *out_dptr = dmlc::BeginPtr(*vec); + } else if (!std::strcmp(field, "group_ptr")) { + vec = &info.group_ptr_; } else { - LOG(FATAL) << "Unknown uint field name " << field; + LOG(FATAL) << "Unknown comp uint field name " << field + << " with comparison " << std::strcmp(field, "group_ptr"); } + *out_len = static_cast(vec->size()); + *out_dptr = dmlc::BeginPtr(*vec); API_END(); } diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index d42d67f14..331686b2d 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -1,6 +1,16 @@ import numpy as np from scipy.sparse import csr_matrix import xgboost +import sys +import os +from sklearn.datasets import load_svmlight_files +import unittest +import itertools +import glob +import shutil +import urllib.request +import zipfile + def test_ranking_with_unweighted_data(): Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17]) @@ -63,3 +73,110 @@ def test_ranking_with_weighted_data(): # the ranking predictor will first try to correctly sort the last query group # before correctly sorting other groups. 
assert all(p <= q for p, q in zip(is_sorted, is_sorted[1:])) + + +class TestRanking(unittest.TestCase): + + @classmethod + def setUpClass(cls): + """ + Download and setup the test fixtures + """ + # download the test data + cls.dpath = 'demo/rank/' + src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip' + target = cls.dpath + '/MQ2008.zip' + urllib.request.urlretrieve(url=src, filename=target) + + with zipfile.ZipFile(target, 'r') as f: + f.extractall(path=cls.dpath) + + x_train, y_train, qid_train, x_test, y_test, qid_test, x_valid, y_valid, qid_valid = load_svmlight_files( + (cls.dpath + "MQ2008/Fold1/train.txt", + cls.dpath + "MQ2008/Fold1/test.txt", + cls.dpath + "MQ2008/Fold1/vali.txt"), + query_id=True, zero_based=False) + # instantiate the matrices + cls.dtrain = xgboost.DMatrix(x_train, y_train) + cls.dvalid = xgboost.DMatrix(x_valid, y_valid) + cls.dtest = xgboost.DMatrix(x_test, y_test) + # set the group counts from the query IDs + cls.dtrain.set_group([len(list(items)) + for _key, items in itertools.groupby(qid_train)]) + cls.dtest.set_group([len(list(items)) + for _key, items in itertools.groupby(qid_test)]) + cls.dvalid.set_group([len(list(items)) + for _key, items in itertools.groupby(qid_valid)]) + # save the query IDs for testing + cls.qid_train = qid_train + cls.qid_test = qid_test + cls.qid_valid = qid_valid + + # model training parameters + cls.params = {'objective': 'rank:pairwise', + 'booster': 'gbtree', + 'silent': 0, + 'eval_metric': ['ndcg'] + } + + @classmethod + def tearDownClass(cls): + """ + Cleanup test artifacts from download and unpacking + :return: + """ + os.remove(cls.dpath + "MQ2008.zip") + shutil.rmtree(cls.dpath + "MQ2008") + + def test_training(self): + """ + Train an XGBoost ranking model + """ + # specify validations set to watch performance + watchlist = [(self.dtest, 'eval'), (self.dtrain, 'train')] + bst = xgboost.train(self.params, self.dtrain, num_boost_round=2500, + early_stopping_rounds=10, 
evals=watchlist) + assert bst.best_score > 0.98 + + def test_cv(self): + """ + Test cross-validation with a group specified + """ + cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500, + early_stopping_rounds=10, nfold=10, as_pandas=False) + assert isinstance(cv, dict) + self.assertSetEqual(set(cv.keys()), {'test-ndcg-mean', 'train-ndcg-mean', 'test-ndcg-std', 'train-ndcg-std'}, + "CV results dict key mismatch") + + def test_cv_no_shuffle(self): + """ + Test cross-validation with a group specified and shuffling disabled + """ + cv = xgboost.cv(self.params, self.dtrain, num_boost_round=2500, + early_stopping_rounds=10, shuffle=False, nfold=10, as_pandas=False) + assert isinstance(cv, dict) + assert len(cv) == 4 + + def test_get_group(self): + """ + Retrieve the group boundaries from the dmatrix + """ + # control that should work + self.dtrain.get_uint_info('root_index') + # test the new getter + self.dtrain.get_uint_info('group_ptr') + + for d, qid in [(self.dtrain, self.qid_train), + (self.dvalid, self.qid_valid), + (self.dtest, self.qid_test)]: + # size of each group + group_sizes = np.array([len(list(items)) + for _key, items in itertools.groupby(qid)]) + # indexes of group boundaries + group_limits = d.get_uint_info('group_ptr') + assert len(group_limits) == len(group_sizes)+1 + assert np.array_equal(np.diff(group_limits), group_sizes) + assert np.array_equal( + group_sizes, np.diff(d.get_uint_info('group_ptr'))) + assert np.array_equal(group_sizes, np.diff(d.get_uint_info('group_ptr'))) + assert np.array_equal(group_limits, d.get_uint_info('group_ptr'))