[breaking] Save booster feature info in JSON, remove feature name generation. (#6605)
* Save feature info in booster in JSON model. * [breaking] Remove automatic feature name generation in `DMatrix`. This PR is to enable reliable feature validation in Python package.
This commit is contained in:
@@ -1022,5 +1022,50 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field,
|
||||
const char **features,
|
||||
const xgboost::bst_ulong size) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
std::vector<std::string> feature_info;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
feature_info.emplace_back(features[i]);
|
||||
}
|
||||
if (!std::strcmp(field, "feature_name")) {
|
||||
learner->SetFeatureNames(feature_info);
|
||||
} else if (!std::strcmp(field, "feature_type")) {
|
||||
learner->SetFeatureTypes(feature_info);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown field for Booster feature info:" << field;
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
|
||||
xgboost::bst_ulong *len,
|
||||
const char ***out_features) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto const *learner = static_cast<Learner const *>(handle);
|
||||
std::vector<const char *> &charp_vecs =
|
||||
learner->GetThreadLocal().ret_vec_charp;
|
||||
std::vector<std::string> &str_vecs = learner->GetThreadLocal().ret_vec_str;
|
||||
if (!std::strcmp(field, "feature_name")) {
|
||||
learner->GetFeatureNames(&str_vecs);
|
||||
} else if (!std::strcmp(field, "feature_type")) {
|
||||
learner->GetFeatureTypes(&str_vecs);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown field for Booster feature info:" << field;
|
||||
}
|
||||
charp_vecs.resize(str_vecs.size());
|
||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||
charp_vecs[i] = str_vecs[i].c_str();
|
||||
}
|
||||
*out_features = dmlc::BeginPtr(charp_vecs);
|
||||
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
||||
API_END();
|
||||
}
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
Reference in New Issue
Block a user