Refactor histogram building code for gpu_hist (#4528)
This commit is contained in:
@@ -162,7 +162,7 @@ std::vector<GradientPairPrecise> GetHostHistGpair() {
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
|
||||
void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
int const kNRows = 16, kNCols = 8;
|
||||
|
||||
TrainParam param;
|
||||
@@ -170,7 +170,6 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
|
||||
param.max_leaves = 0;
|
||||
|
||||
DeviceShard<GradientSumT> shard(0, 0, 0, kNRows, param, kNCols);
|
||||
|
||||
BuildGidx(&shard, kNRows, kNCols);
|
||||
|
||||
xgboost::SimpleLCG gen;
|
||||
@@ -202,7 +201,8 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
|
||||
thrust::device_pointer_cast(shard.ridx.Current()),
|
||||
thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size()));
|
||||
|
||||
builder.Build(&shard, 0);
|
||||
shard.use_shared_memory_histograms = use_shared_memory_histograms;
|
||||
shard.BuildHist(0);
|
||||
DeviceHistogram<GradientSumT> d_hist = shard.hist;
|
||||
|
||||
auto node_histogram = d_hist.GetNodeHistogram(0);
|
||||
@@ -224,17 +224,13 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
|
||||
}
|
||||
|
||||
TEST(GpuHist, BuildHistGlobalMem) {
|
||||
GlobalMemHistBuilder<GradientPairPrecise> double_builder;
|
||||
TestBuildHist(double_builder);
|
||||
GlobalMemHistBuilder<GradientPair> float_builder;
|
||||
TestBuildHist(float_builder);
|
||||
TestBuildHist<GradientPairPrecise>(false);
|
||||
TestBuildHist<GradientPair>(false);
|
||||
}
|
||||
|
||||
TEST(GpuHist, BuildHistSharedMem) {
|
||||
SharedMemHistBuilder<GradientPairPrecise> double_builder;
|
||||
TestBuildHist(double_builder);
|
||||
SharedMemHistBuilder<GradientPair> float_builder;
|
||||
TestBuildHist(float_builder);
|
||||
TestBuildHist<GradientPairPrecise>(true);
|
||||
TestBuildHist<GradientPair>(true);
|
||||
}
|
||||
|
||||
common::HistCutMatrix GetHostCutMatrix () {
|
||||
|
||||
Reference in New Issue
Block a user