Implement feature score in GBTree. (#7041)
* Categorical data support. * Eliminate text parsing during feature score computation.
This commit is contained in:
parent
dcd84b3979
commit
7dd29ffd47
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2015~2020 by Contributors
|
||||
* Copyright (c) 2015~2021 by Contributors
|
||||
* \file c_api.h
|
||||
* \author Tianqi Chen
|
||||
* \brief C API of XGBoost, used for interfacing to other languages.
|
||||
@ -1193,4 +1193,28 @@ XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field,
|
||||
XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
|
||||
bst_ulong *len,
|
||||
const char ***out_features);
|
||||
|
||||
/*!
|
||||
* \brief Calculate feature scores for tree models.
|
||||
*
|
||||
* \param handle An instance of Booster
|
||||
* \param json_config Parameters for computing scores. Accepted JSON keys are:
|
||||
* - importance_type: A JSON string with following possible values:
|
||||
* * 'weight': the number of times a feature is used to split the data across all trees.
|
||||
* * 'gain': the average gain across all splits the feature is used in.
|
||||
* * 'cover': the average coverage across all splits the feature is used in.
|
||||
* * 'total_gain': the total gain across all splits the feature is used in.
|
||||
* * 'total_cover': the total coverage across all splits the feature is used in.
|
||||
* - feature_map: An optional JSON string with URI or path to the feature map file.
|
||||
*
|
||||
* \param out_length Length of output arrays.
|
||||
* \param out_features An array of string as feature names, ordered the same as output scores.
|
||||
* \param out_scores An array of floating point as feature scores.
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
|
||||
bst_ulong *out_length,
|
||||
const char ***out_features,
|
||||
float **out_scores);
|
||||
#endif // XGBOOST_C_API_H_
|
||||
|
||||
@ -181,6 +181,12 @@ class GradientBooster : public Model, public Configurable {
|
||||
virtual std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const = 0;
|
||||
|
||||
virtual void FeatureScore(std::string const &importance_type,
|
||||
std::vector<bst_feature_t> *features,
|
||||
std::vector<float> *scores) const {
|
||||
LOG(FATAL) << "`feature_score` is not implemented for current booster.";
|
||||
}
|
||||
/*!
|
||||
* \brief Whether the current booster uses GPU.
|
||||
*/
|
||||
|
||||
@ -152,6 +152,13 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
||||
HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t layer_begin, uint32_t layer_end) = 0;
|
||||
|
||||
/*!
|
||||
* \brief Calculate feature score. See doc in C API for outputs.
|
||||
*/
|
||||
virtual void CalcFeatureScore(std::string const &importance_type,
|
||||
std::vector<bst_feature_t> *features,
|
||||
std::vector<float> *scores) = 0;
|
||||
|
||||
/*
|
||||
* \brief Get number of boosted rounds from gradient booster.
|
||||
*/
|
||||
|
||||
@ -2191,7 +2191,9 @@ class Booster(object):
|
||||
|
||||
return self.get_score(fmap, importance_type='weight')
|
||||
|
||||
def get_score(self, fmap='', importance_type='weight'):
|
||||
def get_score(
|
||||
self, fmap: os.PathLike = '', importance_type: str = 'weight'
|
||||
) -> Dict[str, float]:
|
||||
"""Get feature importance of each feature.
|
||||
Importance type can be defined as:
|
||||
|
||||
@ -2203,9 +2205,9 @@ class Booster(object):
|
||||
|
||||
.. note:: Feature importance is defined only for tree boosters
|
||||
|
||||
Feature importance is only defined when the decision tree model is chosen as base
|
||||
learner (`booster=gbtree`). It is not defined for other base learner types, such
|
||||
as linear learners (`booster=gblinear`).
|
||||
Feature importance is only defined when the decision tree model is chosen as
|
||||
base learner (`booster=gbtree` or `booster=dart`). It is not defined for other
|
||||
base learner types, such as linear learners (`booster=gblinear`).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -2213,86 +2215,33 @@ class Booster(object):
|
||||
The name of feature map file.
|
||||
importance_type: str, default 'weight'
|
||||
One of the importance types defined above.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A map between feature names and their scores.
|
||||
"""
|
||||
fmap = os.fspath(os.path.expanduser(fmap))
|
||||
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
|
||||
raise ValueError('Feature importance is not defined for Booster type {}'
|
||||
.format(self.booster))
|
||||
|
||||
allowed_importance_types = ['weight', 'gain', 'cover', 'total_gain', 'total_cover']
|
||||
if importance_type not in allowed_importance_types:
|
||||
msg = ("importance_type mismatch, got '{}', expected one of " +
|
||||
repr(allowed_importance_types))
|
||||
raise ValueError(msg.format(importance_type))
|
||||
|
||||
# if it's weight, then omap stores the number of missing values
|
||||
if importance_type == 'weight':
|
||||
# do a simpler tree dump to save time
|
||||
trees = self.get_dump(fmap, with_stats=False)
|
||||
fmap = {}
|
||||
for tree in trees:
|
||||
for line in tree.split('\n'):
|
||||
# look for the opening square bracket
|
||||
arr = line.split('[')
|
||||
# if no opening bracket (leaf node), ignore this line
|
||||
if len(arr) == 1:
|
||||
continue
|
||||
|
||||
# extract feature name from string between []
|
||||
fid = arr[1].split(']')[0].split('<')[0]
|
||||
|
||||
if fid not in fmap:
|
||||
# if the feature hasn't been seen yet
|
||||
fmap[fid] = 1
|
||||
else:
|
||||
fmap[fid] += 1
|
||||
|
||||
return fmap
|
||||
|
||||
average_over_splits = True
|
||||
if importance_type == 'total_gain':
|
||||
importance_type = 'gain'
|
||||
average_over_splits = False
|
||||
elif importance_type == 'total_cover':
|
||||
importance_type = 'cover'
|
||||
average_over_splits = False
|
||||
|
||||
trees = self.get_dump(fmap, with_stats=True)
|
||||
|
||||
importance_type += '='
|
||||
fmap = {}
|
||||
gmap = {}
|
||||
for tree in trees:
|
||||
for line in tree.split('\n'):
|
||||
# look for the opening square bracket
|
||||
arr = line.split('[')
|
||||
# if no opening bracket (leaf node), ignore this line
|
||||
if len(arr) == 1:
|
||||
continue
|
||||
|
||||
# look for the closing bracket, extract only info within that bracket
|
||||
fid = arr[1].split(']')
|
||||
|
||||
# extract gain or cover from string after closing bracket
|
||||
g = float(fid[1].split(importance_type)[1].split(',')[0])
|
||||
|
||||
# extract feature name from string before closing bracket
|
||||
fid = fid[0].split('<')[0]
|
||||
|
||||
if fid not in fmap:
|
||||
# if the feature hasn't been seen yet
|
||||
fmap[fid] = 1
|
||||
gmap[fid] = g
|
||||
else:
|
||||
fmap[fid] += 1
|
||||
gmap[fid] += g
|
||||
|
||||
# calculate average value (gain/cover) for each feature
|
||||
if average_over_splits:
|
||||
for fid in gmap:
|
||||
gmap[fid] = gmap[fid] / fmap[fid]
|
||||
|
||||
return gmap
|
||||
args = from_pystr_to_cstr(
|
||||
json.dumps({"importance_type": importance_type, "feature_map": fmap})
|
||||
)
|
||||
features = ctypes.POINTER(ctypes.c_char_p)()
|
||||
scores = ctypes.POINTER(ctypes.c_float)()
|
||||
length = c_bst_ulong()
|
||||
_check_call(
|
||||
_LIB.XGBoosterFeatureScore(
|
||||
self.handle,
|
||||
args,
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(features),
|
||||
ctypes.byref(scores)
|
||||
)
|
||||
)
|
||||
features_arr = from_cstr_to_pystr(features, length)
|
||||
scores_arr = ctypes2numpy(scores, length.value, np.float32)
|
||||
results = {}
|
||||
for feat, score in zip(features_arr, scores_arr):
|
||||
results[feat] = score
|
||||
return results
|
||||
|
||||
def trees_to_dataframe(self, fmap=''):
|
||||
"""Parse a boosted tree model text dump into a pandas DataFrame structure.
|
||||
|
||||
@ -1098,5 +1098,47 @@ XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle,
|
||||
const char *json_config,
|
||||
xgboost::bst_ulong* out_length,
|
||||
const char ***out_features,
|
||||
float **out_scores) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
auto config = Json::Load(StringView{json_config});
|
||||
auto importance = get<String const>(config["importance_type"]);
|
||||
std::string feature_map_uri;
|
||||
if (!IsA<Null>(config["feature_map"])) {
|
||||
feature_map_uri = get<String const>(config["feature_map"]);
|
||||
}
|
||||
FeatureMap feature_map = LoadFeatureMap(feature_map_uri);
|
||||
|
||||
auto& scores = learner->GetThreadLocal().ret_vec_float;
|
||||
std::vector<bst_feature_t> features;
|
||||
learner->CalcFeatureScore(importance, &features, &scores);
|
||||
|
||||
auto n_features = learner->GetNumFeature();
|
||||
GenerateFeatureMap(learner, n_features, &feature_map);
|
||||
CHECK_LE(features.size(), n_features);
|
||||
|
||||
auto& feature_names = learner->GetThreadLocal().ret_vec_str;
|
||||
feature_names.resize(features.size());
|
||||
auto& feature_names_c = learner->GetThreadLocal().ret_vec_charp;
|
||||
feature_names_c.resize(features.size());
|
||||
|
||||
for (bst_feature_t i = 0; i < features.size(); ++i) {
|
||||
feature_names[i] = feature_map.Name(features[i]);
|
||||
feature_names_c[i] = feature_names[i].data();
|
||||
}
|
||||
|
||||
CHECK_EQ(scores.size(), features.size());
|
||||
CHECK_EQ(scores.size(), feature_names.size());
|
||||
*out_length = scores.size();
|
||||
*out_scores = scores.data();
|
||||
*out_features = dmlc::BeginPtr(feature_names_c);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
@ -7,6 +7,8 @@
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json.h"
|
||||
@ -181,5 +183,45 @@ class XGBoostAPIGuard {
|
||||
RestoreGPUAttribute();
|
||||
}
|
||||
};
|
||||
|
||||
inline FeatureMap LoadFeatureMap(std::string const& uri) {
|
||||
FeatureMap feat;
|
||||
if (uri.size() != 0) {
|
||||
std::unique_ptr<dmlc::Stream> fs(dmlc::Stream::Create(uri.c_str(), "r"));
|
||||
dmlc::istream is(fs.get());
|
||||
feat.LoadText(is);
|
||||
}
|
||||
return feat;
|
||||
}
|
||||
|
||||
// FIXME(jiamingy): Use this for model dump.
|
||||
inline void GenerateFeatureMap(Learner const *learner,
|
||||
size_t n_features, FeatureMap *out_feature_map) {
|
||||
auto &feature_map = *out_feature_map;
|
||||
auto maybe = [&](std::vector<std::string> const &values, size_t i,
|
||||
std::string const &dft) {
|
||||
return values.empty() ? dft : values[i];
|
||||
};
|
||||
if (feature_map.Size() == 0) {
|
||||
// Use the feature names and types from booster.
|
||||
std::vector<std::string> feature_names;
|
||||
learner->GetFeatureNames(&feature_names);
|
||||
if (!feature_names.empty()) {
|
||||
CHECK_EQ(feature_names.size(), n_features) << "Incorrect number of feature names.";
|
||||
}
|
||||
std::vector<std::string> feature_types;
|
||||
learner->GetFeatureTypes(&feature_types);
|
||||
if (!feature_types.empty()) {
|
||||
CHECK_EQ(feature_types.size(), n_features) << "Incorrect number of feature types.";
|
||||
}
|
||||
for (size_t i = 0; i < n_features; ++i) {
|
||||
feature_map.PushBack(
|
||||
i,
|
||||
maybe(feature_names, i, "f" + std::to_string(i)).data(),
|
||||
maybe(feature_types, i, "q").data());
|
||||
}
|
||||
}
|
||||
CHECK_EQ(feature_map.Size(), n_features);
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
|
||||
#include <dmlc/omp.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
@ -299,6 +300,58 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
}
|
||||
|
||||
void FeatureScore(std::string const &importance_type,
|
||||
std::vector<bst_feature_t> *features,
|
||||
std::vector<float> *scores) const override {
|
||||
// Because feature with no importance doesn't appear in the return value so
|
||||
// we need to set up another pair of vectors to store the values during
|
||||
// computation.
|
||||
std::vector<size_t> split_counts(this->model_.learner_model_param->num_feature, 0);
|
||||
std::vector<float> gain_map(this->model_.learner_model_param->num_feature, 0);
|
||||
auto add_score = [&](auto fn) {
|
||||
for (auto const &p_tree : model_.trees) {
|
||||
p_tree->WalkTree([&](bst_node_t nidx) {
|
||||
auto const &node = (*p_tree)[nidx];
|
||||
if (!node.IsLeaf()) {
|
||||
split_counts[node.SplitIndex()]++;
|
||||
fn(p_tree, nidx, node.SplitIndex());
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if (importance_type == "weight") {
|
||||
add_score([&](auto const &p_tree, bst_node_t, bst_feature_t split) {
|
||||
gain_map[split] = split_counts[split];
|
||||
});
|
||||
}
|
||||
if (importance_type == "gain" || importance_type == "total_gain") {
|
||||
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
|
||||
gain_map[split] += p_tree->Stat(nidx).loss_chg;
|
||||
});
|
||||
}
|
||||
if (importance_type == "cover" || importance_type == "total_cover") {
|
||||
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
|
||||
gain_map[split] += p_tree->Stat(nidx).sum_hess;
|
||||
});
|
||||
}
|
||||
if (importance_type == "gain" || importance_type == "cover") {
|
||||
for (size_t i = 0; i < gain_map.size(); ++i) {
|
||||
gain_map[i] /= std::max(1.0f, static_cast<float>(split_counts[i]));
|
||||
}
|
||||
}
|
||||
|
||||
features->clear();
|
||||
scores->clear();
|
||||
for (size_t i = 0; i < split_counts.size(); ++i) {
|
||||
if (split_counts[i] != 0) {
|
||||
features->push_back(i);
|
||||
scores->push_back(gain_map[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
uint32_t layer_begin, uint32_t layer_end) override {
|
||||
|
||||
@ -1193,6 +1193,30 @@ class LearnerImpl : public LearnerIO {
|
||||
*out_preds = &out_predictions.predictions;
|
||||
}
|
||||
|
||||
void CalcFeatureScore(std::string const &importance_type,
|
||||
std::vector<bst_feature_t> *features,
|
||||
std::vector<float> *scores) override {
|
||||
this->Configure();
|
||||
std::vector<std::string> allowed_importance_type = {
|
||||
"weight", "total_gain", "total_cover", "gain", "cover"
|
||||
};
|
||||
if (std::find(allowed_importance_type.begin(),
|
||||
allowed_importance_type.end(),
|
||||
importance_type) == allowed_importance_type.end()) {
|
||||
std::stringstream ss;
|
||||
ss << "importance_type mismatch, got: " << importance_type
|
||||
<< "`, expected one of ";
|
||||
for (size_t i = 0; i < allowed_importance_type.size(); ++i) {
|
||||
ss << "`" << allowed_importance_type[i] << "`";
|
||||
if (i != allowed_importance_type.size() - 1) {
|
||||
ss << ", ";
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << ss.str();
|
||||
}
|
||||
gbm_->FeatureScore(importance_type, features, scores);
|
||||
}
|
||||
|
||||
const std::map<std::string, std::string>& GetConfigurationArguments() const override {
|
||||
return cfg_;
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2020 XGBoost contributors
|
||||
* Copyright 2019-2021 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <dmlc/filesystem.h>
|
||||
@ -410,4 +410,41 @@ TEST(Dart, Slice) {
|
||||
auto const& trees = get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
|
||||
ASSERT_EQ(weights.size(), trees.size());
|
||||
}
|
||||
|
||||
TEST(GBTree, FeatureScore) {
|
||||
size_t n_samples = 1000, n_features = 10, n_classes = 4;
|
||||
auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
|
||||
|
||||
std::unique_ptr<Learner> learner{ Learner::Create({m}) };
|
||||
learner->SetParam("num_class", std::to_string(n_classes));
|
||||
|
||||
learner->Configure();
|
||||
for (size_t i = 0; i < 2; ++i) {
|
||||
learner->UpdateOneIter(i, m);
|
||||
}
|
||||
|
||||
std::vector<bst_feature_t> features_weight;
|
||||
std::vector<float> scores_weight;
|
||||
learner->CalcFeatureScore("weight", &features_weight, &scores_weight);
|
||||
ASSERT_EQ(features_weight.size(), scores_weight.size());
|
||||
ASSERT_LE(features_weight.size(), learner->GetNumFeature());
|
||||
ASSERT_TRUE(std::is_sorted(features_weight.begin(), features_weight.end()));
|
||||
|
||||
auto test_eq = [&learner, &scores_weight](std::string type) {
|
||||
std::vector<bst_feature_t> features;
|
||||
std::vector<float> scores;
|
||||
learner->CalcFeatureScore(type, &features, &scores);
|
||||
|
||||
std::vector<bst_feature_t> features_total;
|
||||
std::vector<float> scores_total;
|
||||
learner->CalcFeatureScore("total_" + type, &features_total, &scores_total);
|
||||
|
||||
for (size_t i = 0; i < scores_weight.size(); ++i) {
|
||||
ASSERT_LE(RelError(scores_total[i] / scores[i], scores_weight[i]), kRtEps);
|
||||
}
|
||||
};
|
||||
|
||||
test_eq("gain");
|
||||
test_eq("cover");
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -154,6 +154,23 @@ class TestBasic:
|
||||
dump4j = json.loads(dump4[0])
|
||||
assert 'gain' in dump4j, "Expected 'gain' to be dumped in JSON."
|
||||
|
||||
def test_feature_score(self):
|
||||
rng = np.random.RandomState(0)
|
||||
data = rng.randn(100, 2)
|
||||
target = np.array([0, 1] * 50)
|
||||
features = ["F0"]
|
||||
with pytest.raises(ValueError):
|
||||
xgb.DMatrix(data, label=target, feature_names=features)
|
||||
|
||||
params = {"objective": "binary:logistic"}
|
||||
dm = xgb.DMatrix(data, label=target, feature_names=["F0", "F1"])
|
||||
booster = xgb.train(params, dm, num_boost_round=1)
|
||||
# no error since feature names might be assigned before the booster seeing data
|
||||
# and booster doesn't known about the actual number of features.
|
||||
booster.feature_names = ["F0"]
|
||||
with pytest.raises(ValueError):
|
||||
booster.get_fscore()
|
||||
|
||||
def test_load_file_invalid(self):
|
||||
with pytest.raises(xgb.core.XGBoostError):
|
||||
xgb.Booster(model_file='incorrect_path')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user