Support gamma in GPU_Hist. (#4874)

* Just prevent building the tree instead of using an explicit pruner.
This commit is contained in:
Jiaming Yuan 2019-09-24 10:16:08 +08:00 committed by GitHub
parent a40b72d127
commit 0b89cd1dfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 3 deletions

View File

@ -72,8 +72,9 @@ struct ExpandEntry {
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) { if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false; return false;
} }
if (param.max_depth > 0 && depth == param.max_depth) return false; if (split.loss_chg < param.min_split_loss) { return false; }
if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false; if (param.max_depth > 0 && depth == param.max_depth) {return false; }
if (param.max_leaves > 0 && num_leaves == param.max_leaves) { return false; }
return true; return true;
} }

View File

@ -133,7 +133,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
auto page = BuildEllpackPage(kNRows, kNCols); auto page = BuildEllpackPage(kNRows, kNCols);
DeviceShard<GradientSumT> shard(0, page.get(), kNRows, param, kNCols, kNCols); DeviceShard<GradientSumT> shard(0, page.get(), kNRows, param, kNCols, kNCols);
shard.InitHistogram(); shard.InitHistogram();
xgboost::SimpleLCG gen; xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f); xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
std::vector<GradientPair> h_gpair(kNRows); std::vector<GradientPair> h_gpair(kNRows);
@ -336,5 +336,66 @@ TEST(GpuHist, TestHistogramIndex) {
TestHistogramIndexImpl(); TestHistogramIndexImpl();
} }
/*!
 * \brief Run a single GPU_Hist update with the given gamma and report tree growth.
 *
 * `gamma` is an alias of `min_split_loss`.  Every other knob passed here is
 * set so that, as far as this test is concerned, gamma is the parameter under
 * examination.
 *
 * \param dmat  Training matrix handed to the updater.
 * \param gamma Value passed as the "gamma" training parameter.
 * \param gpair Gradient pairs handed to the updater.
 *
 * \return The value of RegTree::NumExtraNodes() after one call to Update().
 */
int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPair>* gpair) {
  Args args {
    {"max_depth", "1"},
    {"max_leaves", "0"},

    // Column sampling, regularisation and step limits: neutral settings.
    {"colsample_bynode", "1"},
    {"colsample_bylevel", "1"},
    {"colsample_bytree", "1"},
    {"min_child_weight", "0.01"},
    {"reg_alpha", "0"},
    {"reg_lambda", "0"},
    {"max_delta_step", "0"},

    // The parameter under test.
    {"gamma", std::to_string(gamma)}
  };

  GenericParameter generic_param(CreateEmptyGenericParam(0));
  tree::GPUHistMakerSpecialised<GradientPairPrecise> updater;
  updater.Configure(args, &generic_param);

  RegTree tree;
  updater.Update(gpair, dmat, {&tree});

  return tree.NumExtraNodes();
}
// Checks that a large gamma (min_split_loss) stops GPU_Hist from growing the
// tree, while a small gamma on the same data still lets it split.
TEST(GpuHist, MinSplitLoss) {
  constexpr size_t kRows = 32;
  constexpr size_t kCols = 16;
  constexpr float kSparsity = 0.6;

  // Released manually with `delete` at the end of the test.
  auto dmat = CreateDMatrix(kRows, kCols, kSparsity, 3);

  // Fill the gradient pairs from a fixed generator; grad is drawn before hess
  // for every row.
  xgboost::SimpleLCG gen;
  xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
  std::vector<GradientPair> h_gpair(kRows);
  for (auto &gpair : h_gpair) {
    bst_float grad = dist(&gen);
    bst_float hess = dist(&gen);
    gpair = GradientPair(grad, hess);
  }
  HostDeviceVector<GradientPair> gpair(h_gpair);

  // EXPECT_* (non-fatal) is used instead of ASSERT_* so that a failing check
  // does not return from the test body before `delete dmat` below runs.
  {
    int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 0.01, &gpair);
    // This is not strictly verified, meaning the number `2` is whatever GPU_Hist returned
    // when writing this test, and is only used to show that the larger gamma (below) does
    // prevent building the tree.
    EXPECT_EQ(n_nodes, 2);
  }
  {
    int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 100.0, &gpair);
    // No new nodes with gamma == 100.
    EXPECT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
  }

  delete dmat;
}
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost