Add support for cross-validation using query ID (#4474)
* adding support for matrix slicing with query ID for cross-validation * hail mary test of unrar installation for windows tests * trying to modify tests to run in Github CI * Remove dependency on wget and unrar * Save error log from R test * Relax assertion in test_training * Use int instead of bool in C function interface * Revise R interface * Add XGDMatrixSliceDMatrixEx and keep old XGDMatrixSliceDMatrix for API compatibility
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
5a567ec249
commit
278562db13
@@ -674,6 +674,14 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
const int* idxset,
|
||||
xgboost::bst_ulong len,
|
||||
DMatrixHandle* out) {
|
||||
return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0);
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
|
||||
const int* idxset,
|
||||
xgboost::bst_ulong len,
|
||||
DMatrixHandle* out,
|
||||
int allow_groups) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
@@ -682,8 +690,10 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
|
||||
data::SimpleCSRSource& ret = *source;
|
||||
|
||||
CHECK_EQ(src.info.group_ptr_.size(), 0U)
|
||||
if (!allow_groups) {
|
||||
CHECK_EQ(src.info.group_ptr_.size(), 0U)
|
||||
<< "slice does not support group structure";
|
||||
}
|
||||
|
||||
ret.Clear();
|
||||
ret.info.num_row_ = len;
|
||||
@@ -814,11 +824,14 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
||||
const std::vector<unsigned>* vec = nullptr;
|
||||
if (!std::strcmp(field, "root_index")) {
|
||||
vec = &info.root_index_;
|
||||
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
|
||||
*out_dptr = dmlc::BeginPtr(*vec);
|
||||
} else if (!std::strcmp(field, "group_ptr")) {
|
||||
vec = &info.group_ptr_;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown uint field name " << field;
|
||||
LOG(FATAL) << "Unknown comp uint field name " << field
|
||||
<< " with comparison " << std::strcmp(field, "group_ptr");
|
||||
}
|
||||
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
|
||||
*out_dptr = dmlc::BeginPtr(*vec);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user