Implement feature score in GBTree. (#7041)
* Categorical data support. * Eliminate text parsing during feature score computation.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2020 XGBoost contributors
|
||||
* Copyright 2019-2021 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <dmlc/filesystem.h>
|
||||
@@ -410,4 +410,41 @@ TEST(Dart, Slice) {
|
||||
auto const& trees = get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
|
||||
ASSERT_EQ(weights.size(), trees.size());
|
||||
}
|
||||
|
||||
TEST(GBTree, FeatureScore) {
|
||||
size_t n_samples = 1000, n_features = 10, n_classes = 4;
|
||||
auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
|
||||
|
||||
std::unique_ptr<Learner> learner{ Learner::Create({m}) };
|
||||
learner->SetParam("num_class", std::to_string(n_classes));
|
||||
|
||||
learner->Configure();
|
||||
for (size_t i = 0; i < 2; ++i) {
|
||||
learner->UpdateOneIter(i, m);
|
||||
}
|
||||
|
||||
std::vector<bst_feature_t> features_weight;
|
||||
std::vector<float> scores_weight;
|
||||
learner->CalcFeatureScore("weight", &features_weight, &scores_weight);
|
||||
ASSERT_EQ(features_weight.size(), scores_weight.size());
|
||||
ASSERT_LE(features_weight.size(), learner->GetNumFeature());
|
||||
ASSERT_TRUE(std::is_sorted(features_weight.begin(), features_weight.end()));
|
||||
|
||||
auto test_eq = [&learner, &scores_weight](std::string type) {
|
||||
std::vector<bst_feature_t> features;
|
||||
std::vector<float> scores;
|
||||
learner->CalcFeatureScore(type, &features, &scores);
|
||||
|
||||
std::vector<bst_feature_t> features_total;
|
||||
std::vector<float> scores_total;
|
||||
learner->CalcFeatureScore("total_" + type, &features_total, &scores_total);
|
||||
|
||||
for (size_t i = 0; i < scores_weight.size(); ++i) {
|
||||
ASSERT_LE(RelError(scores_total[i] / scores[i], scores_weight[i]), kRtEps);
|
||||
}
|
||||
};
|
||||
|
||||
test_eq("gain");
|
||||
test_eq("cover");
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user