[breaking] Save booster feature info in JSON, remove feature name generation. (#6605)
* Save feature info in the booster in the JSON model. * [breaking] Remove automatic feature name generation in `DMatrix`. This PR enables reliable feature validation in the Python package.
This commit is contained in:
@@ -1022,5 +1022,50 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Set string-valued per-feature information on a booster.
//
// handle   - booster handle; validated by CHECK_HANDLE() and used as a Learner.
// field    - selects what is set: "feature_name" or "feature_type".
// features - array of `size` NUL-terminated C strings, copied into the learner.
// size     - number of entries in `features`.
//
// Invalid arguments and unknown field names are reported through LOG(FATAL)
// inside the API_BEGIN()/API_END() wrapper.
XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field,
                                       const char **features,
                                       const xgboost::bst_ulong size) {
  API_BEGIN();
  CHECK_HANDLE();
  // std::strcmp on a null pointer is undefined behaviour; fail explicitly.
  if (field == nullptr) {
    LOG(FATAL) << "Invalid pointer argument: field";
  }
  // A null array is acceptable only when there is nothing to read from it.
  if (features == nullptr && size != 0) {
    LOG(FATAL) << "Invalid pointer argument: features";
  }
  auto *learner = static_cast<Learner *>(handle);
  std::vector<std::string> feature_info;
  feature_info.reserve(size);
  for (size_t i = 0; i < size; ++i) {
    // Constructing std::string from a null char pointer is undefined behaviour.
    if (features[i] == nullptr) {
      LOG(FATAL) << "Invalid pointer argument: features[" << i << "]";
    }
    feature_info.emplace_back(features[i]);
  }
  if (!std::strcmp(field, "feature_name")) {
    learner->SetFeatureNames(feature_info);
  } else if (!std::strcmp(field, "feature_type")) {
    learner->SetFeatureTypes(feature_info);
  } else {
    LOG(FATAL) << "Unknown field for Booster feature info:" << field;
  }
  API_END();
}
|
||||
|
||||
// Fetch string-valued per-feature information from a booster.
//
// handle       - booster handle; validated by CHECK_HANDLE() and used as a
//                const Learner.
// field        - selects what is read: "feature_name" or "feature_type".
// len          - receives the number of returned strings.
// out_features - receives a pointer to an array of `*len` C strings.
//
// The returned array and the strings it points to are stored in
// learner->GetThreadLocal() (ret_vec_charp / ret_vec_str); this function
// overwrites both buffers every time it runs.
//
// Invalid arguments and unknown field names are reported through LOG(FATAL)
// inside the API_BEGIN()/API_END() wrapper.
XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
                                       xgboost::bst_ulong *len,
                                       const char ***out_features) {
  API_BEGIN();
  CHECK_HANDLE();
  // std::strcmp on a null pointer is undefined behaviour; fail explicitly.
  if (field == nullptr) {
    LOG(FATAL) << "Invalid pointer argument: field";
  }
  // Both output parameters are written unconditionally below.
  if (len == nullptr) {
    LOG(FATAL) << "Invalid pointer argument: len";
  }
  if (out_features == nullptr) {
    LOG(FATAL) << "Invalid pointer argument: out_features";
  }
  auto const *learner = static_cast<Learner const *>(handle);
  std::vector<const char *> &charp_vecs =
      learner->GetThreadLocal().ret_vec_charp;
  std::vector<std::string> &str_vecs = learner->GetThreadLocal().ret_vec_str;
  if (!std::strcmp(field, "feature_name")) {
    learner->GetFeatureNames(&str_vecs);
  } else if (!std::strcmp(field, "feature_type")) {
    learner->GetFeatureTypes(&str_vecs);
  } else {
    LOG(FATAL) << "Unknown field for Booster feature info:" << field;
  }
  // Expose the strings as a C array of pointers into str_vecs.
  charp_vecs.resize(str_vecs.size());
  for (size_t i = 0; i < str_vecs.size(); ++i) {
    charp_vecs[i] = str_vecs[i].c_str();
  }
  *out_features = dmlc::BeginPtr(charp_vecs);
  *len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
  API_END();
}
|
||||
|
||||
// force link rabit
|
||||
// Marked DMLC_ATTRIBUTE_UNUSED: what matters here is the initializer's call to
// RabitLinkTag(), which makes this file reference a rabit symbol
// (presumably so the rabit C API is kept at link time -- confirm).
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
@@ -256,6 +256,11 @@ class LearnerConfiguration : public Learner {
|
||||
std::map<std::string, std::string> cfg_;
|
||||
// Stores information like best-iteration for early stopping.
|
||||
std::map<std::string, std::string> attributes_;
|
||||
// Name of each feature, usually set from DMatrix.
|
||||
std::vector<std::string> feature_names_;
|
||||
// Type of each feature, usually set from DMatrix.
|
||||
std::vector<std::string> feature_types_;
|
||||
|
||||
common::Monitor monitor_;
|
||||
LearnerModelParamLegacy mparam_;
|
||||
LearnerModelParam learner_model_param_;
|
||||
@@ -460,6 +465,23 @@ class LearnerConfiguration : public Learner {
|
||||
return true;
|
||||
}
|
||||
|
||||
void SetFeatureNames(std::vector<std::string> const& fn) override {
|
||||
feature_names_ = fn;
|
||||
}
|
||||
|
||||
void GetFeatureNames(std::vector<std::string>* fn) const override {
|
||||
*fn = feature_names_;
|
||||
}
|
||||
|
||||
void SetFeatureTypes(std::vector<std::string> const& ft) override {
|
||||
this->feature_types_ = ft;
|
||||
}
|
||||
|
||||
void GetFeatureTypes(std::vector<std::string>* p_ft) const override {
|
||||
auto& ft = *p_ft;
|
||||
ft = this->feature_types_;
|
||||
}
|
||||
|
||||
std::vector<std::string> GetAttrNames() const override {
|
||||
std::vector<std::string> out;
|
||||
for (auto const& kv : attributes_) {
|
||||
@@ -666,6 +688,25 @@ class LearnerIO : public LearnerConfiguration {
|
||||
attributes_[kv.first] = get<String const>(kv.second);
|
||||
}
|
||||
|
||||
// Feature names and types are saved in the model as of xgboost 1.4; each field is loaded only when it is present.
|
||||
auto it = learner.find("feature_names");
|
||||
if (it != learner.cend()) {
|
||||
auto const &feature_names = get<Array const>(it->second);
|
||||
feature_names_.clear();
|
||||
for (auto const &name : feature_names) {
|
||||
feature_names_.emplace_back(get<String const>(name));
|
||||
}
|
||||
}
|
||||
it = learner.find("feature_types");
|
||||
if (it != learner.cend()) {
|
||||
auto const &feature_types = get<Array const>(it->second);
|
||||
feature_types_.clear();
|
||||
for (auto const &name : feature_types) {
|
||||
auto type = get<String const>(name);
|
||||
feature_types_.emplace_back(type);
|
||||
}
|
||||
}
|
||||
|
||||
this->need_configuration_ = true;
|
||||
}
|
||||
|
||||
@@ -691,6 +732,17 @@ class LearnerIO : public LearnerConfiguration {
|
||||
for (auto const& kv : attributes_) {
|
||||
learner["attributes"][kv.first] = String(kv.second);
|
||||
}
|
||||
|
||||
learner["feature_names"] = Array();
|
||||
auto& feature_names = get<Array>(learner["feature_names"]);
|
||||
for (auto const& name : feature_names_) {
|
||||
feature_names.emplace_back(name);
|
||||
}
|
||||
learner["feature_types"] = Array();
|
||||
auto& feature_types = get<Array>(learner["feature_types"]);
|
||||
for (auto const& type : feature_types_) {
|
||||
feature_types.emplace_back(type);
|
||||
}
|
||||
}
|
||||
// About to be deprecated by JSON format
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
|
||||
@@ -385,7 +385,7 @@ class JsonGenerator : public TreeGenerator {
|
||||
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
|
||||
auto cond = tree[nid].SplitCond();
|
||||
static std::string const kNodeTemplate =
|
||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
|
||||
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||
R"I("missing": {missing})I";
|
||||
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||
|
||||
Reference in New Issue
Block a user