[breaking] Save booster feature info in JSON, remove feature name generation. (#6605)
* Save feature info in booster in JSON model. * [breaking] Remove automatic feature name generation in `DMatrix`. This PR is to enable reliable feature validation in Python package.
This commit is contained in:
@@ -256,6 +256,11 @@ class LearnerConfiguration : public Learner {
|
||||
std::map<std::string, std::string> cfg_;
|
||||
// Stores information like best-iteration for early stopping.
|
||||
std::map<std::string, std::string> attributes_;
|
||||
// Name of each feature, usually set from DMatrix.
|
||||
std::vector<std::string> feature_names_;
|
||||
// Type of each feature, usually set from DMatrix.
|
||||
std::vector<std::string> feature_types_;
|
||||
|
||||
common::Monitor monitor_;
|
||||
LearnerModelParamLegacy mparam_;
|
||||
LearnerModelParam learner_model_param_;
|
||||
@@ -460,6 +465,23 @@ class LearnerConfiguration : public Learner {
|
||||
return true;
|
||||
}
|
||||
|
||||
void SetFeatureNames(std::vector<std::string> const& fn) override {
|
||||
feature_names_ = fn;
|
||||
}
|
||||
|
||||
void GetFeatureNames(std::vector<std::string>* fn) const override {
|
||||
*fn = feature_names_;
|
||||
}
|
||||
|
||||
void SetFeatureTypes(std::vector<std::string> const& ft) override {
|
||||
this->feature_types_ = ft;
|
||||
}
|
||||
|
||||
void GetFeatureTypes(std::vector<std::string>* p_ft) const override {
|
||||
auto& ft = *p_ft;
|
||||
ft = this->feature_types_;
|
||||
}
|
||||
|
||||
std::vector<std::string> GetAttrNames() const override {
|
||||
std::vector<std::string> out;
|
||||
for (auto const& kv : attributes_) {
|
||||
@@ -666,6 +688,25 @@ class LearnerIO : public LearnerConfiguration {
|
||||
attributes_[kv.first] = get<String const>(kv.second);
|
||||
}
|
||||
|
||||
// feature names and types are saved in xgboost 1.4
|
||||
auto it = learner.find("feature_names");
|
||||
if (it != learner.cend()) {
|
||||
auto const &feature_names = get<Array const>(it->second);
|
||||
feature_names_.clear();
|
||||
for (auto const &name : feature_names) {
|
||||
feature_names_.emplace_back(get<String const>(name));
|
||||
}
|
||||
}
|
||||
it = learner.find("feature_types");
|
||||
if (it != learner.cend()) {
|
||||
auto const &feature_types = get<Array const>(it->second);
|
||||
feature_types_.clear();
|
||||
for (auto const &name : feature_types) {
|
||||
auto type = get<String const>(name);
|
||||
feature_types_.emplace_back(type);
|
||||
}
|
||||
}
|
||||
|
||||
this->need_configuration_ = true;
|
||||
}
|
||||
|
||||
@@ -691,6 +732,17 @@ class LearnerIO : public LearnerConfiguration {
|
||||
for (auto const& kv : attributes_) {
|
||||
learner["attributes"][kv.first] = String(kv.second);
|
||||
}
|
||||
|
||||
learner["feature_names"] = Array();
|
||||
auto& feature_names = get<Array>(learner["feature_names"]);
|
||||
for (auto const& name : feature_names_) {
|
||||
feature_names.emplace_back(name);
|
||||
}
|
||||
learner["feature_types"] = Array();
|
||||
auto& feature_types = get<Array>(learner["feature_types"]);
|
||||
for (auto const& type : feature_types_) {
|
||||
feature_types.emplace_back(type);
|
||||
}
|
||||
}
|
||||
// About to be deprecated by JSON format
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
|
||||
Reference in New Issue
Block a user