[breaking] Save booster feature info in JSON, remove feature name generation. (#6605)

* Save feature info in booster in JSON model.
* [breaking] Remove automatic feature name generation in `DMatrix`.

This PR enables reliable feature validation in the Python package.
This commit is contained in:
Jiaming Yuan
2021-02-25 18:54:16 +08:00
committed by GitHub
parent b6167cd2ff
commit 9da2287ab8
12 changed files with 363 additions and 36 deletions

View File

@@ -1022,5 +1022,50 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
API_END();
}
// Store string-valued feature information on a Booster.
//
// handle   : Booster handle; cast to Learner once CHECK_HANDLE() has run.
// field    : selects what is stored, either "feature_name" or "feature_type".
//            Any other value is reported through LOG(FATAL).
// features : array of `size` NUL-terminated C strings. May be null only when
//            `size` is 0.
// size     : number of entries in `features`.
//
// The strings are copied into a std::vector<std::string> before being handed
// to the learner, so the caller's array is not referenced by that vector.
// NOTE(review): the return value comes from API_BEGIN()/API_END(), which are
// defined outside this block; presumably they turn LOG(FATAL) into an error
// return code. Confirm against the macro definitions.
XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field,
                                       const char **features,
                                       const xgboost::bst_ulong size) {
  API_BEGIN();
  CHECK_HANDLE();
  // Passing a null pointer to std::strcmp or to the std::string constructor
  // is undefined behaviour, so reject null arguments explicitly.
  if (field == nullptr) {
    LOG(FATAL) << "Invalid pointer argument: field";
  }
  if (features == nullptr && size != 0) {
    LOG(FATAL) << "Invalid pointer argument: features";
  }
  auto *learner = static_cast<Learner *>(handle);

  std::vector<std::string> feature_info;
  feature_info.reserve(size);
  for (size_t i = 0; i < size; ++i) {
    if (features[i] == nullptr) {
      LOG(FATAL) << "Invalid pointer argument: features[" << i << "]";
    }
    feature_info.emplace_back(features[i]);
  }

  if (!std::strcmp(field, "feature_name")) {
    learner->SetFeatureNames(feature_info);
  } else if (!std::strcmp(field, "feature_type")) {
    learner->SetFeatureTypes(feature_info);
  } else {
    LOG(FATAL) << "Unknown field for Booster feature info:" << field;
  }
  API_END();
}
// Retrieve string-valued feature information from a Booster.
//
// handle       : Booster handle; cast to Learner const once CHECK_HANDLE()
//                has run.
// field        : selects what is read, either "feature_name" or
//                "feature_type". Any other value is reported through
//                LOG(FATAL).
// len          : out; receives the number of returned strings.
// out_features : out; receives a pointer to an array of `*len` C strings.
//
// Both the string storage and the pointer array live in the buffers returned
// by learner->GetThreadLocal() (ret_vec_str / ret_vec_charp); the caller does
// not own them. NOTE(review): they are presumably overwritten by the next
// call that reuses those buffers. Confirm the intended lifetime.
// NOTE(review): the return value comes from API_BEGIN()/API_END(), which are
// defined outside this block.
XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
                                       xgboost::bst_ulong *len,
                                       const char ***out_features) {
  API_BEGIN();
  CHECK_HANDLE();
  // Validate every pointer up front: std::strcmp on a null `field` and the
  // writes through `len` / `out_features` below would otherwise be undefined
  // behaviour.
  if (field == nullptr) {
    LOG(FATAL) << "Invalid pointer argument: field";
  }
  if (len == nullptr) {
    LOG(FATAL) << "Invalid pointer argument: len";
  }
  if (out_features == nullptr) {
    LOG(FATAL) << "Invalid pointer argument: out_features";
  }
  auto const *learner = static_cast<Learner const *>(handle);
  std::vector<const char *> &charp_vecs =
      learner->GetThreadLocal().ret_vec_charp;
  std::vector<std::string> &str_vecs = learner->GetThreadLocal().ret_vec_str;

  if (!std::strcmp(field, "feature_name")) {
    learner->GetFeatureNames(&str_vecs);
  } else if (!std::strcmp(field, "feature_type")) {
    learner->GetFeatureTypes(&str_vecs);
  } else {
    LOG(FATAL) << "Unknown field for Booster feature info:" << field;
  }

  // Build the C-style view only after str_vecs is final, so each pointer
  // refers to the string currently stored at that index.
  charp_vecs.resize(str_vecs.size());
  for (size_t i = 0; i < str_vecs.size(); ++i) {
    charp_vecs[i] = str_vecs[i].c_str();
  }
  *out_features = dmlc::BeginPtr(charp_vecs);
  *len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
  API_END();
}
// Force link rabit: initializing this otherwise-unused static with the result
// of RabitLinkTag() makes this translation unit reference that symbol.
// DMLC_ATTRIBUTE_UNUSED marks the variable as intentionally unused.
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();