Implement feature score in GBTree. (#7041)
* Categorical data support. * Eliminate text parsing during feature score computation.
This commit is contained in:
@@ -1098,5 +1098,47 @@ XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle,
|
||||
const char *json_config,
|
||||
xgboost::bst_ulong* out_length,
|
||||
const char ***out_features,
|
||||
float **out_scores) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
auto config = Json::Load(StringView{json_config});
|
||||
auto importance = get<String const>(config["importance_type"]);
|
||||
std::string feature_map_uri;
|
||||
if (!IsA<Null>(config["feature_map"])) {
|
||||
feature_map_uri = get<String const>(config["feature_map"]);
|
||||
}
|
||||
FeatureMap feature_map = LoadFeatureMap(feature_map_uri);
|
||||
|
||||
auto& scores = learner->GetThreadLocal().ret_vec_float;
|
||||
std::vector<bst_feature_t> features;
|
||||
learner->CalcFeatureScore(importance, &features, &scores);
|
||||
|
||||
auto n_features = learner->GetNumFeature();
|
||||
GenerateFeatureMap(learner, n_features, &feature_map);
|
||||
CHECK_LE(features.size(), n_features);
|
||||
|
||||
auto& feature_names = learner->GetThreadLocal().ret_vec_str;
|
||||
feature_names.resize(features.size());
|
||||
auto& feature_names_c = learner->GetThreadLocal().ret_vec_charp;
|
||||
feature_names_c.resize(features.size());
|
||||
|
||||
for (bst_feature_t i = 0; i < features.size(); ++i) {
|
||||
feature_names[i] = feature_map.Name(features[i]);
|
||||
feature_names_c[i] = feature_names[i].data();
|
||||
}
|
||||
|
||||
CHECK_EQ(scores.size(), features.size());
|
||||
CHECK_EQ(scores.size(), feature_names.size());
|
||||
*out_length = scores.size();
|
||||
*out_scores = scores.data();
|
||||
*out_features = dmlc::BeginPtr(feature_names_c);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
Reference in New Issue
Block a user