Move feature names and types of DMatrix from Python to C++. (#5858)

* Add thread local return entry for DMatrix.
* Save feature name and feature type in binary file.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-07-07 09:40:13 +08:00
committed by GitHub
parent 4b0852ee41
commit 93c44a9a64
12 changed files with 451 additions and 84 deletions

View File

@@ -283,6 +283,38 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
API_END();
}
XGB_DLL int XGDMatrixSetStrFeatureInfo(DMatrixHandle handle, const char *field,
const char **c_info,
const xgboost::bst_ulong size) {
API_BEGIN();
CHECK_HANDLE();
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
info.SetFeatureInfo(field, c_info, size);
API_END();
}
XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
xgboost::bst_ulong *len,
const char ***out_features) {
API_BEGIN();
CHECK_HANDLE();
auto m = *static_cast<std::shared_ptr<DMatrix>*>(handle);
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
std::vector<const char *> &charp_vecs = m->GetThreadLocal().ret_vec_charp;
std::vector<std::string> &str_vecs = m->GetThreadLocal().ret_vec_str;
info.GetFeatureInfo(field, &str_vecs);
charp_vecs.resize(str_vecs.size());
for (size_t i = 0; i < str_vecs.size(); ++i) {
charp_vecs[i] = str_vecs[i].c_str();
}
*out_features = dmlc::BeginPtr(charp_vecs);
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
API_END();
}
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group,
xgboost::bst_ulong len) {
@@ -301,22 +333,7 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
API_BEGIN();
CHECK_HANDLE();
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
const std::vector<bst_float>* vec = nullptr;
if (!std::strcmp(field, "label")) {
vec = &info.labels_.HostVector();
} else if (!std::strcmp(field, "weight")) {
vec = &info.weights_.HostVector();
} else if (!std::strcmp(field, "base_margin")) {
vec = &info.base_margin_.HostVector();
} else if (!std::strcmp(field, "label_lower_bound")) {
vec = &info.labels_lower_bound_.HostVector();
} else if (!std::strcmp(field, "label_upper_bound")) {
vec = &info.labels_upper_bound_.HostVector();
} else {
LOG(FATAL) << "Unknown float field name " << field;
}
*out_len = static_cast<xgboost::bst_ulong>(vec->size()); // NOLINT
*out_dptr = dmlc::BeginPtr(*vec);
info.GetInfo(field, out_len, DataType::kFloat32, reinterpret_cast<void const**>(out_dptr));
API_END();
}
@@ -327,14 +344,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
API_BEGIN();
CHECK_HANDLE();
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
const std::vector<unsigned>* vec = nullptr;
if (!std::strcmp(field, "group_ptr")) {
vec = &info.group_ptr_;
} else {
LOG(FATAL) << "Unknown uint field name " << field;
}
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
*out_dptr = dmlc::BeginPtr(*vec);
info.GetInfo(field, out_len, DataType::kUInt32, reinterpret_cast<void const**>(out_dptr));
API_END();
}