Implement feature score in GBTree. (#7041)

* Categorical data support.
* Eliminate text parsing during feature score computation.
This commit is contained in:
Jiaming Yuan 2021-06-18 11:53:16 +08:00 committed by GitHub
parent dcd84b3979
commit 7dd29ffd47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 285 additions and 84 deletions

View File

@ -1,5 +1,5 @@
/*!
* Copyright (c) 2015~2020 by Contributors
* Copyright (c) 2015~2021 by Contributors
* \file c_api.h
* \author Tianqi Chen
* \brief C API of XGBoost, used for interfacing to other languages.
@ -1193,4 +1193,28 @@ XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field,
XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
bst_ulong *len,
const char ***out_features);
/*!
* \brief Calculate feature scores for tree models.
*
* \param handle An instance of Booster
* \param json_config Parameters for computing scores. Accepted JSON keys are:
* - importance_type: A JSON string with following possible values:
* * 'weight': the number of times a feature is used to split the data across all trees.
* * 'gain': the average gain across all splits the feature is used in.
* * 'cover': the average coverage across all splits the feature is used in.
* * 'total_gain': the total gain across all splits the feature is used in.
* * 'total_cover': the total coverage across all splits the feature is used in.
* - feature_map: An optional JSON string with URI or path to the feature map file.
*
* \param out_length Length of output arrays.
* \param out_features An array of string as feature names, ordered the same as output scores.
* \param out_scores An array of floating point as feature scores.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
bst_ulong *out_length,
const char ***out_features,
float **out_scores);
#endif // XGBOOST_C_API_H_

View File

@ -181,6 +181,12 @@ class GradientBooster : public Model, public Configurable {
virtual std::vector<std::string> DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const = 0;
virtual void FeatureScore(std::string const &importance_type,
std::vector<bst_feature_t> *features,
std::vector<float> *scores) const {
LOG(FATAL) << "`feature_score` is not implemented for current booster.";
}
/*!
* \brief Whether the current booster uses GPU.
*/

View File

@ -152,6 +152,13 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) = 0;
/*!
* \brief Calculate feature score. See doc in C API for outputs.
*/
virtual void CalcFeatureScore(std::string const &importance_type,
std::vector<bst_feature_t> *features,
std::vector<float> *scores) = 0;
/*
* \brief Get number of boosted rounds from gradient booster.
*/

View File

@ -2191,7 +2191,9 @@ class Booster(object):
return self.get_score(fmap, importance_type='weight')
def get_score(self, fmap='', importance_type='weight'):
def get_score(
self, fmap: os.PathLike = '', importance_type: str = 'weight'
) -> Dict[str, float]:
"""Get feature importance of each feature.
Importance type can be defined as:
@ -2203,9 +2205,9 @@ class Booster(object):
.. note:: Feature importance is defined only for tree boosters
Feature importance is only defined when the decision tree model is chosen as base
learner (`booster=gbtree`). It is not defined for other base learner types, such
as linear learners (`booster=gblinear`).
Feature importance is only defined when the decision tree model is chosen as
base learner (`booster=gbtree` or `booster=dart`). It is not defined for other
base learner types, such as linear learners (`booster=gblinear`).
Parameters
----------
@ -2213,86 +2215,33 @@ class Booster(object):
The name of feature map file.
importance_type: str, default 'weight'
One of the importance types defined above.
Returns
-------
A map between feature names and their scores.
"""
fmap = os.fspath(os.path.expanduser(fmap))
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
raise ValueError('Feature importance is not defined for Booster type {}'
.format(self.booster))
allowed_importance_types = ['weight', 'gain', 'cover', 'total_gain', 'total_cover']
if importance_type not in allowed_importance_types:
msg = ("importance_type mismatch, got '{}', expected one of " +
repr(allowed_importance_types))
raise ValueError(msg.format(importance_type))
# if it's weight, then omap stores the number of missing values
if importance_type == 'weight':
# do a simpler tree dump to save time
trees = self.get_dump(fmap, with_stats=False)
fmap = {}
for tree in trees:
for line in tree.split('\n'):
# look for the opening square bracket
arr = line.split('[')
# if no opening bracket (leaf node), ignore this line
if len(arr) == 1:
continue
# extract feature name from string between []
fid = arr[1].split(']')[0].split('<')[0]
if fid not in fmap:
# if the feature hasn't been seen yet
fmap[fid] = 1
else:
fmap[fid] += 1
return fmap
average_over_splits = True
if importance_type == 'total_gain':
importance_type = 'gain'
average_over_splits = False
elif importance_type == 'total_cover':
importance_type = 'cover'
average_over_splits = False
trees = self.get_dump(fmap, with_stats=True)
importance_type += '='
fmap = {}
gmap = {}
for tree in trees:
for line in tree.split('\n'):
# look for the opening square bracket
arr = line.split('[')
# if no opening bracket (leaf node), ignore this line
if len(arr) == 1:
continue
# look for the closing bracket, extract only info within that bracket
fid = arr[1].split(']')
# extract gain or cover from string after closing bracket
g = float(fid[1].split(importance_type)[1].split(',')[0])
# extract feature name from string before closing bracket
fid = fid[0].split('<')[0]
if fid not in fmap:
# if the feature hasn't been seen yet
fmap[fid] = 1
gmap[fid] = g
else:
fmap[fid] += 1
gmap[fid] += g
# calculate average value (gain/cover) for each feature
if average_over_splits:
for fid in gmap:
gmap[fid] = gmap[fid] / fmap[fid]
return gmap
args = from_pystr_to_cstr(
json.dumps({"importance_type": importance_type, "feature_map": fmap})
)
features = ctypes.POINTER(ctypes.c_char_p)()
scores = ctypes.POINTER(ctypes.c_float)()
length = c_bst_ulong()
_check_call(
_LIB.XGBoosterFeatureScore(
self.handle,
args,
ctypes.byref(length),
ctypes.byref(features),
ctypes.byref(scores)
)
)
features_arr = from_cstr_to_pystr(features, length)
scores_arr = ctypes2numpy(scores, length.value, np.float32)
results = {}
for feat, score in zip(features_arr, scores_arr):
results[feat] = score
return results
def trees_to_dataframe(self, fmap=''):
"""Parse a boosted tree model text dump into a pandas DataFrame structure.

View File

@ -1098,5 +1098,47 @@ XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
API_END();
}
XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle,
const char *json_config,
xgboost::bst_ulong* out_length,
const char ***out_features,
float **out_scores) {
API_BEGIN();
CHECK_HANDLE();
auto *learner = static_cast<Learner *>(handle);
auto config = Json::Load(StringView{json_config});
auto importance = get<String const>(config["importance_type"]);
std::string feature_map_uri;
if (!IsA<Null>(config["feature_map"])) {
feature_map_uri = get<String const>(config["feature_map"]);
}
FeatureMap feature_map = LoadFeatureMap(feature_map_uri);
auto& scores = learner->GetThreadLocal().ret_vec_float;
std::vector<bst_feature_t> features;
learner->CalcFeatureScore(importance, &features, &scores);
auto n_features = learner->GetNumFeature();
GenerateFeatureMap(learner, n_features, &feature_map);
CHECK_LE(features.size(), n_features);
auto& feature_names = learner->GetThreadLocal().ret_vec_str;
feature_names.resize(features.size());
auto& feature_names_c = learner->GetThreadLocal().ret_vec_charp;
feature_names_c.resize(features.size());
for (bst_feature_t i = 0; i < features.size(); ++i) {
feature_names[i] = feature_map.Name(features[i]);
feature_names_c[i] = feature_names[i].data();
}
CHECK_EQ(scores.size(), features.size());
CHECK_EQ(scores.size(), feature_names.size());
*out_length = scores.size();
*out_scores = scores.data();
*out_features = dmlc::BeginPtr(feature_names_c);
API_END();
}
// force link rabit
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();

View File

@ -7,6 +7,8 @@
#include <algorithm>
#include <functional>
#include <vector>
#include <memory>
#include <string>
#include "xgboost/logging.h"
#include "xgboost/json.h"
@ -181,5 +183,45 @@ class XGBoostAPIGuard {
RestoreGPUAttribute();
}
};
inline FeatureMap LoadFeatureMap(std::string const& uri) {
FeatureMap feat;
if (uri.size() != 0) {
std::unique_ptr<dmlc::Stream> fs(dmlc::Stream::Create(uri.c_str(), "r"));
dmlc::istream is(fs.get());
feat.LoadText(is);
}
return feat;
}
// FIXME(jiamingy): Use this for model dump.
inline void GenerateFeatureMap(Learner const *learner,
size_t n_features, FeatureMap *out_feature_map) {
auto &feature_map = *out_feature_map;
auto maybe = [&](std::vector<std::string> const &values, size_t i,
std::string const &dft) {
return values.empty() ? dft : values[i];
};
if (feature_map.Size() == 0) {
// Use the feature names and types from booster.
std::vector<std::string> feature_names;
learner->GetFeatureNames(&feature_names);
if (!feature_names.empty()) {
CHECK_EQ(feature_names.size(), n_features) << "Incorrect number of feature names.";
}
std::vector<std::string> feature_types;
learner->GetFeatureTypes(&feature_types);
if (!feature_types.empty()) {
CHECK_EQ(feature_types.size(), n_features) << "Incorrect number of feature types.";
}
for (size_t i = 0; i < n_features; ++i) {
feature_map.PushBack(
i,
maybe(feature_names, i, "f" + std::to_string(i)).data(),
maybe(feature_types, i, "q").data());
}
}
CHECK_EQ(feature_map.Size(), n_features);
}
} // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_

View File

@ -9,6 +9,7 @@
#include <dmlc/omp.h>
#include <algorithm>
#include <vector>
#include <map>
#include <memory>
@ -299,6 +300,58 @@ class GBTree : public GradientBooster {
}
}
void FeatureScore(std::string const &importance_type,
std::vector<bst_feature_t> *features,
std::vector<float> *scores) const override {
// Because feature with no importance doesn't appear in the return value so
// we need to set up another pair of vectors to store the values during
// computation.
std::vector<size_t> split_counts(this->model_.learner_model_param->num_feature, 0);
std::vector<float> gain_map(this->model_.learner_model_param->num_feature, 0);
auto add_score = [&](auto fn) {
for (auto const &p_tree : model_.trees) {
p_tree->WalkTree([&](bst_node_t nidx) {
auto const &node = (*p_tree)[nidx];
if (!node.IsLeaf()) {
split_counts[node.SplitIndex()]++;
fn(p_tree, nidx, node.SplitIndex());
}
return true;
});
}
};
if (importance_type == "weight") {
add_score([&](auto const &p_tree, bst_node_t, bst_feature_t split) {
gain_map[split] = split_counts[split];
});
}
if (importance_type == "gain" || importance_type == "total_gain") {
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
gain_map[split] += p_tree->Stat(nidx).loss_chg;
});
}
if (importance_type == "cover" || importance_type == "total_cover") {
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
gain_map[split] += p_tree->Stat(nidx).sum_hess;
});
}
if (importance_type == "gain" || importance_type == "cover") {
for (size_t i = 0; i < gain_map.size(); ++i) {
gain_map[i] /= std::max(1.0f, static_cast<float>(split_counts[i]));
}
}
features->clear();
scores->clear();
for (size_t i = 0; i < split_counts.size(); ++i) {
if (split_counts[i] != 0) {
features->push_back(i);
scores->push_back(gain_map[i]);
}
}
}
void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
uint32_t layer_begin, uint32_t layer_end) override {

View File

@ -1193,6 +1193,30 @@ class LearnerImpl : public LearnerIO {
*out_preds = &out_predictions.predictions;
}
void CalcFeatureScore(std::string const &importance_type,
std::vector<bst_feature_t> *features,
std::vector<float> *scores) override {
this->Configure();
std::vector<std::string> allowed_importance_type = {
"weight", "total_gain", "total_cover", "gain", "cover"
};
if (std::find(allowed_importance_type.begin(),
allowed_importance_type.end(),
importance_type) == allowed_importance_type.end()) {
std::stringstream ss;
ss << "importance_type mismatch, got: " << importance_type
<< "`, expected one of ";
for (size_t i = 0; i < allowed_importance_type.size(); ++i) {
ss << "`" << allowed_importance_type[i] << "`";
if (i != allowed_importance_type.size() - 1) {
ss << ", ";
}
}
LOG(FATAL) << ss.str();
}
gbm_->FeatureScore(importance_type, features, scores);
}
const std::map<std::string, std::string>& GetConfigurationArguments() const override {
return cfg_;
}

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2019-2020 XGBoost contributors
* Copyright 2019-2021 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
@ -410,4 +410,41 @@ TEST(Dart, Slice) {
auto const& trees = get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
ASSERT_EQ(weights.size(), trees.size());
}
TEST(GBTree, FeatureScore) {
size_t n_samples = 1000, n_features = 10, n_classes = 4;
auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
std::unique_ptr<Learner> learner{ Learner::Create({m}) };
learner->SetParam("num_class", std::to_string(n_classes));
learner->Configure();
for (size_t i = 0; i < 2; ++i) {
learner->UpdateOneIter(i, m);
}
std::vector<bst_feature_t> features_weight;
std::vector<float> scores_weight;
learner->CalcFeatureScore("weight", &features_weight, &scores_weight);
ASSERT_EQ(features_weight.size(), scores_weight.size());
ASSERT_LE(features_weight.size(), learner->GetNumFeature());
ASSERT_TRUE(std::is_sorted(features_weight.begin(), features_weight.end()));
auto test_eq = [&learner, &scores_weight](std::string type) {
std::vector<bst_feature_t> features;
std::vector<float> scores;
learner->CalcFeatureScore(type, &features, &scores);
std::vector<bst_feature_t> features_total;
std::vector<float> scores_total;
learner->CalcFeatureScore("total_" + type, &features_total, &scores_total);
for (size_t i = 0; i < scores_weight.size(); ++i) {
ASSERT_LE(RelError(scores_total[i] / scores[i], scores_weight[i]), kRtEps);
}
};
test_eq("gain");
test_eq("cover");
}
} // namespace xgboost

View File

@ -154,6 +154,23 @@ class TestBasic:
dump4j = json.loads(dump4[0])
assert 'gain' in dump4j, "Expected 'gain' to be dumped in JSON."
def test_feature_score(self):
rng = np.random.RandomState(0)
data = rng.randn(100, 2)
target = np.array([0, 1] * 50)
features = ["F0"]
with pytest.raises(ValueError):
xgb.DMatrix(data, label=target, feature_names=features)
params = {"objective": "binary:logistic"}
dm = xgb.DMatrix(data, label=target, feature_names=["F0", "F1"])
booster = xgb.train(params, dm, num_boost_round=1)
# no error since feature names might be assigned before the booster seeing data
# and booster doesn't known about the actual number of features.
booster.feature_names = ["F0"]
with pytest.raises(ValueError):
booster.get_fscore()
def test_load_file_invalid(self):
with pytest.raises(xgb.core.XGBoostError):
xgb.Booster(model_file='incorrect_path')