Support slicing tree model (#6302)

This PR is meant the end the confusion around best_ntree_limit and unify model slicing. We have multi-class and random forests, asking users to understand how to set ntree_limit is difficult and error prone.

* Implement the save_best option in early stopping.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-11-03 02:27:39 -05:00
committed by GitHub
parent 29745c6df2
commit 2cc9662005
19 changed files with 550 additions and 37 deletions

View File

@@ -154,9 +154,9 @@ TEST(GBTree, JsonIO) {
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
auto const& gbtree_model = model["model"]["model"];
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1);
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1ul);
ASSERT_EQ(get<Integer>(get<Object>(get<Array>(gbtree_model["trees"]).front()).at("id")), 0);
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1);
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1ul);
auto j_train_param = model["config"]["gbtree_train_param"];
ASSERT_EQ(get<String>(j_train_param["num_parallel_tree"]), "1");
@@ -194,7 +194,7 @@ TEST(Dart, JsonIO) {
ASSERT_EQ(get<String>(model["model"]["name"]), "dart") << model;
ASSERT_EQ(get<String>(model["config"]["name"]), "dart");
ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0ul);
}
TEST(Dart, Prediction) {
@@ -230,4 +230,122 @@ TEST(Dart, Prediction) {
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps);
}
}
std::pair<Json, Json> TestModelSlice(std::string booster) {
size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3;
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true, false, kClasses);
int32_t kIters = 10;
std::unique_ptr<Learner> learner {
Learner::Create({m})
};
learner->SetParams(Args{{"booster", booster},
{"tree_method", "hist"},
{"num_parallel_tree", std::to_string(kForest)},
{"num_class", std::to_string(kClasses)},
{"subsample", "0.5"},
{"max_depth", "2"}});
for (auto i = 0; i < kIters; ++i) {
learner->UpdateOneIter(i, m);
}
Json model{Object()};
Json config{Object()};
learner->SaveModel(&model);
learner->SaveConfig(&config);
bool out_of_bound = false;
size_t constexpr kSliceStart = 2, kSliceEnd = 8, kStep = 3;
std::unique_ptr<Learner> sliced {learner->Slice(kSliceStart, kSliceEnd, kStep, &out_of_bound)};
Json sliced_model{Object()};
sliced->SaveModel(&sliced_model);
auto get_shape = [&](Json const& model) {
if (booster == "gbtree") {
return get<Object const>(model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]);
} else {
return get<Object const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["gbtree_model_param"]);
}
};
auto const& model_shape = get_shape(sliced_model);
CHECK_EQ(get<String const>(model_shape.at("num_trees")), std::to_string(2 * kClasses * kForest));
Json sliced_config {Object()};
sliced->SaveConfig(&sliced_config);
CHECK_EQ(sliced_config, config);
auto get_trees = [&](Json const& model) {
if (booster == "gbtree") {
return get<Array const>(model["learner"]["gradient_booster"]["model"]["trees"]);
} else {
return get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
}
};
auto get_info = [&](Json const& model) {
if (booster == "gbtree") {
return get<Array const>(model["learner"]["gradient_booster"]["model"]["tree_info"]);
} else {
return get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["tree_info"]);
}
};
auto const &sliced_trees = get_trees(sliced_model);
CHECK_EQ(sliced_trees.size(), 2 * kClasses * kForest);
auto constexpr kLayerSize = kClasses * kForest;
auto const &sliced_info = get_info(sliced_model);
for (size_t layer = 0; layer < 2; ++layer) {
for (size_t j = 0; j < kClasses; ++j) {
for (size_t k = 0; k < kForest; ++k) {
auto idx = layer * kLayerSize + j * kForest + k;
auto const &group = get<Integer const>(sliced_info.at(idx));
CHECK_EQ(static_cast<size_t>(group), j);
}
}
}
auto const& trees = get_trees(model);
// Sliced layers are [2, 5]
auto begin = kLayerSize * kSliceStart;
auto end = begin + kLayerSize;
auto j = 0;
for (size_t i = begin; i < end; ++i) {
Json tree = trees[i];
tree["id"] = Integer(0); // id is different, we set it to 0 to allow comparison.
auto sliced_tree = sliced_trees[j];
sliced_tree["id"] = Integer(0);
CHECK_EQ(tree, sliced_tree);
j++;
}
begin = kLayerSize * (kSliceStart + kStep);
end = begin + kLayerSize;
for (size_t i = begin; i < end; ++i) {
Json tree = trees[i];
tree["id"] = Integer(0);
auto sliced_tree = sliced_trees[j];
sliced_tree["id"] = Integer(0);
CHECK_EQ(tree, sliced_tree);
j++;
}
return std::make_pair(model, sliced_model);
}
TEST(GBTree, Slice) {
TestModelSlice("gbtree");
}
TEST(Dart, Slice) {
Json model, sliced_model;
std::tie(model, sliced_model) = TestModelSlice("dart");
auto const& weights = get<Array const>(model["learner"]["gradient_booster"]["weight_drop"]);
auto const& trees = get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
ASSERT_EQ(weights.size(), trees.size());
}
} // namespace xgboost

View File

@@ -118,7 +118,7 @@ TEST(Learner, Configuration) {
// eval_metric is not part of configuration
auto attr_names = learner->GetConfigurationArguments();
ASSERT_EQ(attr_names.size(), 1);
ASSERT_EQ(attr_names.size(), 1ul);
ASSERT_EQ(attr_names.find(emetric), attr_names.cend());
ASSERT_EQ(attr_names.at("foo"), "bar");
}
@@ -127,7 +127,7 @@ TEST(Learner, Configuration) {
std::unique_ptr<Learner> learner { Learner::Create({nullptr}) };
learner->SetParams({{"foo", "bar"}, {emetric, "auc"}, {emetric, "entropy"}, {emetric, "KL"}});
auto attr_names = learner->GetConfigurationArguments();
ASSERT_EQ(attr_names.size(), 1);
ASSERT_EQ(attr_names.size(), 1ul);
ASSERT_EQ(attr_names.at("foo"), "bar");
}
}
@@ -181,7 +181,7 @@ TEST(Learner, JsonModelIO) {
learner->SaveModel(&new_in);
ASSERT_TRUE(IsA<Object>(out["learner"]["attributes"]));
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1);
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1ul);
ASSERT_EQ(out, new_in);
}
}
@@ -333,5 +333,4 @@ TEST(Learner, Seed) {
ASSERT_EQ(std::to_string(seed),
get<String>(config["learner"]["generic_param"]["seed"]));
}
} // namespace xgboost