[breaking] Save booster feature info in JSON, remove feature name generation. (#6605)

* Save feature info in booster in JSON model.
* [breaking] Remove automatic feature name generation in `DMatrix`.

This PR enables reliable feature validation in the Python package.
This commit is contained in:
Jiaming Yuan
2021-02-25 18:54:16 +08:00
committed by GitHub
parent b6167cd2ff
commit 9da2287ab8
12 changed files with 363 additions and 36 deletions

View File

@@ -256,6 +256,11 @@ class LearnerConfiguration : public Learner {
std::map<std::string, std::string> cfg_;
// Stores information like best-iteration for early stopping.
std::map<std::string, std::string> attributes_;
// Name of each feature, usually set from DMatrix.
// Written to the "feature_names" array of the learner JSON object when the
// model is saved, and replaced from that array on load if the key is present.
std::vector<std::string> feature_names_;
// Type of each feature, usually set from DMatrix.
// Kept as strings; written to / restored from the "feature_types" array of
// the learner JSON object in the same way as the names.
std::vector<std::string> feature_types_;
common::Monitor monitor_;
LearnerModelParamLegacy mparam_;
LearnerModelParam learner_model_param_;
@@ -460,6 +465,23 @@ class LearnerConfiguration : public Learner {
return true;
}
void SetFeatureNames(std::vector<std::string> const& fn) override {
feature_names_ = fn;
}
void GetFeatureNames(std::vector<std::string>* fn) const override {
*fn = feature_names_;
}
void SetFeatureTypes(std::vector<std::string> const& ft) override {
this->feature_types_ = ft;
}
void GetFeatureTypes(std::vector<std::string>* p_ft) const override {
auto& ft = *p_ft;
ft = this->feature_types_;
}
std::vector<std::string> GetAttrNames() const override {
std::vector<std::string> out;
for (auto const& kv : attributes_) {
@@ -666,6 +688,25 @@ class LearnerIO : public LearnerConfiguration {
attributes_[kv.first] = get<String const>(kv.second);
}
// Feature names and types are saved in the JSON model as of xgboost 1.4.
// Each key is looked up with find() and skipped when absent, in which case
// the corresponding member keeps its current value.
auto it = learner.find("feature_names");
if (it != learner.cend()) {
auto const &feature_names = get<Array const>(it->second);
// Drop any previously held names so the model's list replaces them
// instead of being appended to them.
feature_names_.clear();
for (auto const &name : feature_names) {
feature_names_.emplace_back(get<String const>(name));
}
}
// Same procedure for the feature types; `it` is reused for the second lookup.
it = learner.find("feature_types");
if (it != learner.cend()) {
auto const &feature_types = get<Array const>(it->second);
feature_types_.clear();
for (auto const &name : feature_types) {
// Each array element is read as a JSON string and stored as-is.
auto type = get<String const>(name);
feature_types_.emplace_back(type);
}
}
this->need_configuration_ = true;
}
@@ -691,6 +732,17 @@ class LearnerIO : public LearnerConfiguration {
for (auto const& kv : attributes_) {
learner["attributes"][kv.first] = String(kv.second);
}
// Store the feature names as a JSON array under "feature_names". The key is
// first assigned an empty Array, so it is written even when feature_names_
// is empty.
learner["feature_names"] = Array();
auto& feature_names = get<Array>(learner["feature_names"]);
for (auto const& name : feature_names_) {
feature_names.emplace_back(name);
}
// Store the feature types the same way under "feature_types".
learner["feature_types"] = Array();
auto& feature_types = get<Array>(learner["feature_types"]);
for (auto const& type : feature_types_) {
feature_types.emplace_back(type);
}
}
// About to be deprecated by JSON format
void LoadModel(dmlc::Stream* fi) override {