Support gamma in GPU_Hist. (#4874)
* Just prevent building the tree instead of using an explicit pruner.
This commit is contained in:
parent
a40b72d127
commit
0b89cd1dfa
@ -72,8 +72,9 @@ struct ExpandEntry {
|
|||||||
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (param.max_depth > 0 && depth == param.max_depth) return false;
|
if (split.loss_chg < param.min_split_loss) { return false; }
|
||||||
if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
|
if (param.max_depth > 0 && depth == param.max_depth) {return false; }
|
||||||
|
if (param.max_leaves > 0 && num_leaves == param.max_leaves) { return false; }
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -133,7 +133,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
|||||||
auto page = BuildEllpackPage(kNRows, kNCols);
|
auto page = BuildEllpackPage(kNRows, kNCols);
|
||||||
DeviceShard<GradientSumT> shard(0, page.get(), kNRows, param, kNCols, kNCols);
|
DeviceShard<GradientSumT> shard(0, page.get(), kNRows, param, kNCols, kNCols);
|
||||||
shard.InitHistogram();
|
shard.InitHistogram();
|
||||||
|
|
||||||
xgboost::SimpleLCG gen;
|
xgboost::SimpleLCG gen;
|
||||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
||||||
std::vector<GradientPair> h_gpair(kNRows);
|
std::vector<GradientPair> h_gpair(kNRows);
|
||||||
@ -336,5 +336,66 @@ TEST(GpuHist, TestHistogramIndex) {
|
|||||||
TestHistogramIndexImpl();
|
TestHistogramIndexImpl();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gamma is an alias of min_split_loss
|
||||||
|
int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPair>* gpair) {
|
||||||
|
Args args {
|
||||||
|
{"max_depth", "1"},
|
||||||
|
{"max_leaves", "0"},
|
||||||
|
|
||||||
|
// Disable all other parameters.
|
||||||
|
{"colsample_bynode", "1"},
|
||||||
|
{"colsample_bylevel", "1"},
|
||||||
|
{"colsample_bytree", "1"},
|
||||||
|
{"min_child_weight", "0.01"},
|
||||||
|
{"reg_alpha", "0"},
|
||||||
|
{"reg_lambda", "0"},
|
||||||
|
{"max_delta_step", "0"},
|
||||||
|
|
||||||
|
// test gamma
|
||||||
|
{"gamma", std::to_string(gamma)}
|
||||||
|
};
|
||||||
|
|
||||||
|
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
|
||||||
|
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||||
|
hist_maker.Configure(args, &generic_param);
|
||||||
|
|
||||||
|
RegTree tree;
|
||||||
|
hist_maker.Update(gpair, dmat, {&tree});
|
||||||
|
|
||||||
|
auto n_nodes = tree.NumExtraNodes();
|
||||||
|
return n_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GpuHist, MinSplitLoss) {
|
||||||
|
constexpr size_t kRows = 32;
|
||||||
|
constexpr size_t kCols = 16;
|
||||||
|
constexpr float kSparsity = 0.6;
|
||||||
|
auto dmat = CreateDMatrix(kRows, kCols, kSparsity, 3);
|
||||||
|
|
||||||
|
xgboost::SimpleLCG gen;
|
||||||
|
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
||||||
|
std::vector<GradientPair> h_gpair(kRows);
|
||||||
|
for (auto &gpair : h_gpair) {
|
||||||
|
bst_float grad = dist(&gen);
|
||||||
|
bst_float hess = dist(&gen);
|
||||||
|
gpair = GradientPair(grad, hess);
|
||||||
|
}
|
||||||
|
HostDeviceVector<GradientPair> gpair(h_gpair);
|
||||||
|
|
||||||
|
{
|
||||||
|
int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 0.01, &gpair);
|
||||||
|
// This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
|
||||||
|
// when writing this test, and only used for testing larger gamma (below) does prevent
|
||||||
|
// building tree.
|
||||||
|
ASSERT_EQ(n_nodes, 2);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
int32_t n_nodes = TestMinSplitLoss((*dmat).get(), 100.0, &gpair);
|
||||||
|
// No new nodes with gamma == 100.
|
||||||
|
ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
|
||||||
|
}
|
||||||
|
delete dmat;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user