Refactor histogram building code for gpu_hist (#4528)

This commit is contained in:
Rory Mitchell
2019-06-03 09:50:10 +12:00
committed by GitHub
parent 399fabed49
commit 23a10c8339
2 changed files with 53 additions and 90 deletions

View File

@@ -162,7 +162,7 @@ std::vector<GradientPairPrecise> GetHostHistGpair() {
}
template <typename GradientSumT>
void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
void TestBuildHist(bool use_shared_memory_histograms) {
int const kNRows = 16, kNCols = 8;
TrainParam param;
@@ -170,7 +170,6 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
param.max_leaves = 0;
DeviceShard<GradientSumT> shard(0, 0, 0, kNRows, param, kNCols);
BuildGidx(&shard, kNRows, kNCols);
xgboost::SimpleLCG gen;
@@ -202,7 +201,8 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
thrust::device_pointer_cast(shard.ridx.Current()),
thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size()));
builder.Build(&shard, 0);
shard.use_shared_memory_histograms = use_shared_memory_histograms;
shard.BuildHist(0);
DeviceHistogram<GradientSumT> d_hist = shard.hist;
auto node_histogram = d_hist.GetNodeHistogram(0);
@@ -224,17 +224,13 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
}
TEST(GpuHist, BuildHistGlobalMem) {
GlobalMemHistBuilder<GradientPairPrecise> double_builder;
TestBuildHist(double_builder);
GlobalMemHistBuilder<GradientPair> float_builder;
TestBuildHist(float_builder);
TestBuildHist<GradientPairPrecise>(false);
TestBuildHist<GradientPair>(false);
}
TEST(GpuHist, BuildHistSharedMem) {
SharedMemHistBuilder<GradientPairPrecise> double_builder;
TestBuildHist(double_builder);
SharedMemHistBuilder<GradientPair> float_builder;
TestBuildHist(float_builder);
TestBuildHist<GradientPairPrecise>(true);
TestBuildHist<GradientPair>(true);
}
common::HistCutMatrix GetHostCutMatrix () {