Support gamma in GPU_Hist. (#4874)

* Just prevent building the tree instead of using an explicit pruner.
This commit is contained in:
Jiaming Yuan 2019-09-24 10:16:08 +08:00 committed by GitHub
parent a40b72d127
commit 0b89cd1dfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 3 deletions

View File

@ -72,8 +72,9 @@ struct ExpandEntry {
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
if (param.max_depth > 0 && depth == param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
if (split.loss_chg < param.min_split_loss) { return false; }
if (param.max_depth > 0 && depth == param.max_depth) {return false; }
if (param.max_leaves > 0 && num_leaves == param.max_leaves) { return false; }
return true;
}

View File

@ -336,5 +336,66 @@ TEST(GpuHist, TestHistogramIndex) {
TestHistogramIndexImpl();
}
// gamma is an alias of min_split_loss
int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPair>* gpair) {
Args args {
{"max_depth", "1"},
{"max_leaves", "0"},
// Disable all other parameters.
{"colsample_bynode", "1"},
{"colsample_bylevel", "1"},
{"colsample_bytree", "1"},
{"min_child_weight", "0.01"},
{"reg_alpha", "0"},
{"reg_lambda", "0"},
{"max_delta_step", "0"},
// test gamma
{"gamma", std::to_string(gamma)}
};
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
GenericParameter generic_param(CreateEmptyGenericParam(0));
hist_maker.Configure(args, &generic_param);
RegTree tree;
hist_maker.Update(gpair, dmat, {&tree});
auto n_nodes = tree.NumExtraNodes();
return n_nodes;
}
TEST(GpuHist, MinSplitLoss) {
constexpr size_t kRows = 32;
constexpr size_t kCols = 16;
constexpr float kSparsity = 0.6;
auto dmat = CreateDMatrix(kRows, kCols, kSparsity, 3);
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
std::vector<GradientPair> h_gpair(kRows);
for (auto &gpair : h_gpair) {
bst_float grad = dist(&gen);
bst_float hess = dist(&gen);
gpair = GradientPair(grad, hess);
}
HostDeviceVector<GradientPair> gpair(h_gpair);
{
int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 0.01, &gpair);
// This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
// when writing this test, and only used for testing larger gamma (below) does prevent
// building tree.
ASSERT_EQ(n_nodes, 2);
}
{
int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 100.0, &gpair);
// No new nodes with gamma == 100.
ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
}
delete dmat;
}
} // namespace tree
} // namespace xgboost