Requires setting leaf stat when expanding tree. (#5501)

* Fix GPU Hist feature importance.
This commit is contained in:
Jiaming Yuan
2020-04-10 12:27:03 +08:00
committed by GitHub
parent dc2950fd90
commit 7d52c0b8c2
11 changed files with 179 additions and 50 deletions

View File

@@ -42,13 +42,15 @@ TEST(Updater, Prune) {
pruner->Configure(cfg);
// loss_chg < min_split_loss;
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
pruner->Update(&gpair, p_dmat.get(), trees);
ASSERT_EQ(tree.NumExtraNodes(), 0);
// loss_chg > min_split_loss;
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f);
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
pruner->Update(&gpair, p_dmat.get(), trees);
ASSERT_EQ(tree.NumExtraNodes(), 2);
@@ -63,10 +65,12 @@ TEST(Updater, Prune) {
// loss_chg > min_split_loss
tree.ExpandNode(tree[0].LeftChild(),
0, 0.5f, true, 0.3, 0.4, 0.5,
/*loss_chg=*/18.0f, 0.0f);
/*loss_chg=*/18.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.ExpandNode(tree[0].RightChild(),
0, 0.5f, true, 0.3, 0.4, 0.5,
/*loss_chg=*/19.0f, 0.0f);
/*loss_chg=*/19.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
cfg.emplace_back(std::make_pair("max_depth", "1"));
pruner->Configure(cfg);
pruner->Update(&gpair, p_dmat.get(), trees);
@@ -75,7 +79,8 @@ TEST(Updater, Prune) {
tree.ExpandNode(tree[0].LeftChild(),
0, 0.5f, true, 0.3, 0.4, 0.5,
/*loss_chg=*/18.0f, 0.0f);
/*loss_chg=*/18.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
cfg.emplace_back(std::make_pair("min_split_loss", "0"));
pruner->Configure(cfg);
pruner->Update(&gpair, p_dmat.get(), trees);

View File

@@ -34,7 +34,8 @@ TEST(Updater, Refresh) {
std::vector<RegTree*> trees {&tree};
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f);
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
int cleft = tree[0].LeftChild();
int cright = tree[0].RightChild();

View File

@@ -88,13 +88,13 @@ TEST(Tree, Load) {
TEST(Tree, AllocateNode) {
RegTree tree;
tree.ExpandNode(
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.CollapseToLeaf(0, 0);
ASSERT_EQ(tree.NumExtraNodes(), 0);
tree.ExpandNode(
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
ASSERT_EQ(tree.NumExtraNodes(), 2);
auto& nodes = tree.GetNodes();
@@ -107,18 +107,18 @@ RegTree ConstructTree() {
RegTree tree;
tree.ExpandNode(
/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
/*default_left=*/true,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
auto left = tree[0].LeftChild();
auto right = tree[0].RightChild();
tree.ExpandNode(
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
/*default_left=*/false,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
tree.ExpandNode(
/*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f,
/*default_left=*/false,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
return tree;
}
@@ -222,7 +222,8 @@ TEST(Tree, DumpDot) {
TEST(Tree, JsonIO) {
RegTree tree;
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
Json j_tree{Object()};
tree.SaveModel(&j_tree);
@@ -246,8 +247,10 @@ TEST(Tree, JsonIO) {
auto left = tree[0].LeftChild();
auto right = tree[0].RightChild();
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.SaveModel(&j_tree);
tree.ChangeToLeaf(1, 1.0f);

View File

@@ -0,0 +1,59 @@
#include <xgboost/tree_updater.h>
#include <xgboost/tree_model.h>
#include <gtest/gtest.h>
#include "../helpers.h"
namespace xgboost {
class UpdaterTreeStatTest : public ::testing::Test {
protected:
std::shared_ptr<DMatrix> p_dmat_;
HostDeviceVector<GradientPair> gpairs_;
size_t constexpr static kRows = 10;
size_t constexpr static kCols = 10;
protected:
void SetUp() override {
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatix(true);
auto g = GenerateRandomGradients(kRows);
gpairs_.Resize(kRows);
gpairs_.Copy(g);
}
void RunTest(std::string updater) {
auto tparam = CreateEmptyGenericParam(0);
auto up = std::unique_ptr<TreeUpdater>{
TreeUpdater::Create(updater, &tparam)};
up->Configure(Args{});
RegTree tree;
tree.param.num_feature = kCols;
up->Update(&gpairs_, p_dmat_.get(), {&tree});
tree.WalkTree([&tree](bst_node_t nidx) {
if (tree[nidx].IsLeaf()) {
// 1.0 is the default `min_child_weight`.
CHECK_GE(tree.Stat(nidx).sum_hess, 1.0);
}
return true;
});
}
};
#if defined(XGBOOST_USE_CUDA)
TEST_F(UpdaterTreeStatTest, GPUHist) {
this->RunTest("grow_gpu_hist");
}
#endif // defined(XGBOOST_USE_CUDA)
TEST_F(UpdaterTreeStatTest, Hist) {
this->RunTest("grow_quantile_histmaker");
}
TEST_F(UpdaterTreeStatTest, Exact) {
this->RunTest("grow_colmaker");
}
TEST_F(UpdaterTreeStatTest, Approx) {
this->RunTest("grow_histmaker");
}
} // namespace xgboost