Move feature names and types of DMatrix from Python to C++. (#5858)
* Add thread local return entry for DMatrix.
* Save feature name and feature type in binary file.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -283,6 +283,38 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
// C API entry point: store string-valued feature information on a DMatrix.
//
// handle : opaque handle, interpreted here as std::shared_ptr<DMatrix>*.
// field  : name of the feature-info field; forwarded unchanged to SetFeatureInfo.
// c_info : array of C strings; forwarded unchanged to SetFeatureInfo.
// size   : forwarded unchanged (presumably the number of entries in c_info).
//
// The return value is produced by the API_BEGIN/API_END macro pair, which is
// defined outside this view.
XGB_DLL int XGDMatrixSetStrFeatureInfo(DMatrixHandle handle, const char *field,
                                       const char **c_info,
                                       const xgboost::bst_ulong size) {
  API_BEGIN();
  CHECK_HANDLE();
  // The handle points at a shared_ptr that owns the matrix.
  auto *p_dmat = static_cast<std::shared_ptr<DMatrix> *>(handle);
  auto &meta = (*p_dmat)->Info();
  meta.SetFeatureInfo(field, c_info, size);
  API_END();
}
|
||||
|
||||
// C API entry point: fetch string-valued feature information from a DMatrix.
//
// handle       : opaque handle, interpreted here as std::shared_ptr<DMatrix>*.
// field        : name of the feature-info field; forwarded to GetFeatureInfo.
// len          : out-parameter receiving the number of returned strings.
// out_features : out-parameter receiving a pointer to an array of C strings.
//
// Both the strings and the pointer array live in the entry returned by
// DMatrix::GetThreadLocal(); the caller does not own them, and they are
// overwritten the next time this function fills the same entry.
//
// The return value is produced by the API_BEGIN/API_END macro pair, which is
// defined outside this view.
XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
                                       xgboost::bst_ulong *len,
                                       const char ***out_features) {
  API_BEGIN();
  CHECK_HANDLE();
  // Take a copy of the owning pointer stored behind the handle.
  std::shared_ptr<DMatrix> dmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
  auto &meta = dmat->Info();

  // Return buffers held by the matrix's GetThreadLocal() entry.
  auto &entry = dmat->GetThreadLocal();
  std::vector<std::string> &names = entry.ret_vec_str;
  std::vector<const char *> &name_ptrs = entry.ret_vec_charp;

  meta.GetFeatureInfo(field, &names);

  // Rebuild the pointer table so that slot i refers to names[i].
  name_ptrs.clear();
  name_ptrs.reserve(names.size());
  for (auto const &name : names) {
    name_ptrs.push_back(name.c_str());
  }

  *out_features = dmlc::BeginPtr(name_ptrs);
  *len = static_cast<xgboost::bst_ulong>(name_ptrs.size());
  API_END();
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
const unsigned* group,
|
||||
xgboost::bst_ulong len) {
|
||||
@@ -301,22 +333,7 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
||||
const std::vector<bst_float>* vec = nullptr;
|
||||
if (!std::strcmp(field, "label")) {
|
||||
vec = &info.labels_.HostVector();
|
||||
} else if (!std::strcmp(field, "weight")) {
|
||||
vec = &info.weights_.HostVector();
|
||||
} else if (!std::strcmp(field, "base_margin")) {
|
||||
vec = &info.base_margin_.HostVector();
|
||||
} else if (!std::strcmp(field, "label_lower_bound")) {
|
||||
vec = &info.labels_lower_bound_.HostVector();
|
||||
} else if (!std::strcmp(field, "label_upper_bound")) {
|
||||
vec = &info.labels_upper_bound_.HostVector();
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown float field name " << field;
|
||||
}
|
||||
*out_len = static_cast<xgboost::bst_ulong>(vec->size()); // NOLINT
|
||||
*out_dptr = dmlc::BeginPtr(*vec);
|
||||
info.GetInfo(field, out_len, DataType::kFloat32, reinterpret_cast<void const**>(out_dptr));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -327,14 +344,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
||||
const std::vector<unsigned>* vec = nullptr;
|
||||
if (!std::strcmp(field, "group_ptr")) {
|
||||
vec = &info.group_ptr_;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown uint field name " << field;
|
||||
}
|
||||
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
|
||||
*out_dptr = dmlc::BeginPtr(*vec);
|
||||
info.GetInfo(field, out_len, DataType::kUInt32, reinterpret_cast<void const**>(out_dptr));
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user