Add support for cross-validation using query ID (#4474)

* adding support for matrix slicing with query ID for cross-validation

* hail mary test of unrar installation for windows tests

* trying to modify tests to run in Github CI

* Remove dependency on wget and unrar

* Save error log from R test

* Relax assertion in test_training

* Use int instead of bool in C function interface

* Revise R interface

* Add XGDMatrixSliceDMatrixEx and keep old XGDMatrixSliceDMatrix for API compatibility
This commit is contained in:
Bryan Woods
2019-05-23 19:45:02 +02:00
committed by Philip Hyunsu Cho
parent 5a567ec249
commit 278562db13
9 changed files with 223 additions and 18 deletions

View File

@@ -674,6 +674,14 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int* idxset,
xgboost::bst_ulong len,
DMatrixHandle* out) {
return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0);
}
XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
const int* idxset,
xgboost::bst_ulong len,
DMatrixHandle* out,
int allow_groups) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
API_BEGIN();
@@ -682,8 +690,10 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
data::SimpleCSRSource& ret = *source;
CHECK_EQ(src.info.group_ptr_.size(), 0U)
if (!allow_groups) {
CHECK_EQ(src.info.group_ptr_.size(), 0U)
<< "slice does not support group structure";
}
ret.Clear();
ret.info.num_row_ = len;
@@ -814,11 +824,14 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
const std::vector<unsigned>* vec = nullptr;
if (!std::strcmp(field, "root_index")) {
vec = &info.root_index_;
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
*out_dptr = dmlc::BeginPtr(*vec);
} else if (!std::strcmp(field, "group_ptr")) {
vec = &info.group_ptr_;
} else {
LOG(FATAL) << "Unknown uint field name " << field;
LOG(FATAL) << "Unknown comp uint field name " << field
<< " with comparison " << std::strcmp(field, "group_ptr");
}
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
*out_dptr = dmlc::BeginPtr(*vec);
API_END();
}