Refactor histogram building code for gpu_hist (#4528)

Authored by Rory Mitchell on 2019-06-03 09:50:10 +12:00, committed by GitHub
parent 399fabed49
commit 23a10c8339
2 changed files with 53 additions and 90 deletions
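
Summary of the change: the virtual GPUHistBuilderBase hierarchy (SharedMemHistBuilder / GlobalMemHistBuilder) is removed, and both strategies are folded into a single SharedMemHistKernel guarded by a use_shared_memory_histograms flag on DeviceShard, set once at initialisation depending on whether the per-node histogram fits in shared memory. As a rough sketch of the pattern only (simplified, hypothetical names, not the actual XGBoost code), run-time polymorphism is replaced by a flag forwarded into one code path:

// Sketch of the refactoring pattern, not the real implementation.
// Before: the histogram-building strategy was chosen via virtual dispatch.
struct HistBuilderBase {
  virtual void Build(int nidx) = 0;
  virtual ~HistBuilderBase() = default;
};

// After: a single BuildHist path; the only difference between the two
// strategies (where the atomic adds land) travels as a boolean.
struct ShardSketch {
  bool use_shared_memory_histograms{false};  // set once, at init time
  void BuildHist(int nidx) {
    // one kernel launch; the flag is passed to the kernel, which either
    // accumulates in shared memory and flushes, or writes global memory directly
  }
};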


@@ -515,28 +515,38 @@ __global__ void CompressBinEllpackKernel(
 }
 template <typename GradientSumT>
-__global__ void SharedMemHistKernel(ELLPackMatrix matrix, const bst_uint* d_ridx,
+__global__ void SharedMemHistKernel(ELLPackMatrix matrix,
+                                    const bst_uint* d_ridx,
                                     GradientSumT* d_node_hist,
                                     const GradientPair* d_gpair,
-                                    size_t segment_begin, size_t n_elements) {
+                                    size_t segment_begin, size_t n_elements,
+                                    bool use_shared_memory_histograms) {
   extern __shared__ char smem[];
-  GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
-  for (auto i :
-       dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
-    smem_arr[i] = GradientSumT();
+  GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem);  // NOLINT
+  if (use_shared_memory_histograms) {
+    dh::BlockFill(smem_arr, matrix.BinCount(), GradientSumT());
+    __syncthreads();
   }
-  __syncthreads();
   for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
     int ridx = d_ridx[idx / matrix.row_stride + segment_begin];
-    int gidx = matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
+    int gidx =
+        matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
     if (gidx != matrix.null_gidx_value) {
-      AtomicAddGpair(smem_arr + gidx, d_gpair[ridx]);
+      // If we are not using shared memory, accumulate the values directly into
+      // global memory
+      GradientSumT* atomic_add_ptr =
+          use_shared_memory_histograms ? smem_arr : d_node_hist;
+      AtomicAddGpair(atomic_add_ptr + gidx, d_gpair[ridx]);
     }
   }
-  __syncthreads();
-  for (auto i :
-       dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
-    AtomicAddGpair(d_node_hist + i, smem_arr[i]);
+  if (use_shared_memory_histograms) {
+    // Write shared memory back to global memory
+    __syncthreads();
+    for (auto i :
+         dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
+      AtomicAddGpair(d_node_hist + i, smem_arr[i]);
+    }
   }
 }
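
The kernel above keeps the two-phase shared-memory strategy (zero a per-block histogram, accumulate with block-local atomics, then flush every bin to global memory) and now degrades to plain global-memory atomics in the same kernel when use_shared_memory_histograms is false. A minimal standalone CUDA sketch of that pattern, using plain int counts instead of the XGBoost gradient types (hypothetical kernel, not the one in this diff):

#include <cstddef>

// Toy histogram kernel: same privatisation idea, simplified types.
__global__ void ToyHistKernel(const int* bin_ids, std::size_t n,
                              int* global_hist, int num_bins, bool use_shared) {
  extern __shared__ int smem_hist[];
  if (use_shared) {
    // Zero the block-local histogram before any thread writes to it.
    for (int i = threadIdx.x; i < num_bins; i += blockDim.x) smem_hist[i] = 0;
    __syncthreads();
  }
  // Accumulate into the block-local copy, or straight into global memory.
  int* out = use_shared ? smem_hist : global_hist;
  for (std::size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
       idx += static_cast<std::size_t>(gridDim.x) * blockDim.x) {
    atomicAdd(&out[bin_ids[idx]], 1);
  }
  if (use_shared) {
    // Flush the block-local histogram into the global one.
    __syncthreads();
    for (int i = threadIdx.x; i < num_bins; i += blockDim.x) {
      atomicAdd(&global_hist[i], smem_hist[i]);
    }
  }
}

When the shared path is taken, the launch must request num_bins * sizeof(int) bytes of dynamic shared memory, which mirrors the smem_size computation added to BuildHist below.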
@@ -620,16 +630,6 @@ __forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
 #endif
 }
-template <typename GradientSumT>
-struct DeviceShard;
-template <typename GradientSumT>
-struct GPUHistBuilderBase {
- public:
-  virtual void Build(DeviceShard<GradientSumT>* shard, int idx) = 0;
-  virtual ~GPUHistBuilderBase() = default;
-};
 // Manage memory for a single GPU
 template <typename GradientSumT>
 struct DeviceShard {
@@ -679,6 +679,7 @@ struct DeviceShard {
   TrainParam param;
   bool prediction_cache_initialised;
+  bool use_shared_memory_histograms {false};
   dh::CubMemory temp_memory;
   dh::PinnedMemory pinned_memory;
@@ -689,8 +690,6 @@ struct DeviceShard {
   std::vector<ValueConstraint> node_value_constraints;
   common::ColumnSampler column_sampler;
-  std::unique_ptr<GPUHistBuilderBase<GradientSumT>> hist_builder;
   using ExpandQueue =
       std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
                           std::function<bool(ExpandEntry, ExpandEntry)>>;
@@ -870,7 +869,27 @@ struct DeviceShard {
   void BuildHist(int nidx) {
     hist.AllocateHistogram(nidx);
-    hist_builder->Build(this, nidx);
+    auto segment = ridx_segments[nidx];
+    auto d_node_hist = hist.GetNodeHistogram(nidx);
+    auto d_ridx = ridx.Current();
+    auto d_gpair = gpair.data();
+    auto n_elements = segment.Size() * ellpack_matrix.row_stride;
+    const size_t smem_size =
+        use_shared_memory_histograms
+            ? sizeof(GradientSumT) * ellpack_matrix.BinCount()
+            : 0;
+    const int items_per_thread = 8;
+    const int block_threads = 256;
+    const int grid_size = static_cast<int>(
+        dh::DivRoundUp(n_elements, items_per_thread * block_threads));
+    if (grid_size <= 0) {
+      return;
+    }
+    SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
+        ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, segment.begin,
+        n_elements, use_shared_memory_histograms);
   }
   void SubtractionTrick(int nidx_parent, int nidx_histogram,
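
BuildHist now owns the launch configuration: dynamic shared memory is requested only when the shared-memory path is active, and the grid is sized so that each thread strides over roughly items_per_thread ELLPACK slots. A worked example of the grid-size arithmetic (the element count is made up for illustration):

#include <cstddef>
#include <cstdio>

// Rounding-up division, reproduced here only to illustrate the arithmetic.
static std::size_t DivRoundUp(std::size_t a, std::size_t b) {
  return (a + b - 1) / b;
}

int main() {
  const std::size_t n_elements = 1000000;  // rows * row_stride, assumed figure
  const int items_per_thread = 8;
  const int block_threads = 256;
  const std::size_t grid_size =
      DivRoundUp(n_elements, items_per_thread * block_threads);
  std::printf("grid_size = %zu\n", grid_size);  // 489
  return 0;
}

With these numbers the kernel launches 489 blocks of 256 threads, and smem_size stays zero unless shared-memory histograms were enabled at initialisation.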
@@ -1084,7 +1103,6 @@ struct DeviceShard {
       this->AllReduceHist(subtraction_trick_nidx, reducer);
     }
   }
   void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
     RegTree& tree = *p_tree;
@@ -1209,55 +1227,6 @@ struct DeviceShard {
   }
 };
-template <typename GradientSumT>
-struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
-  void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
-    auto segment = shard->ridx_segments[nidx];
-    auto segment_begin = segment.begin;
-    auto d_node_hist = shard->hist.GetNodeHistogram(nidx);
-    auto d_ridx = shard->ridx.Current();
-    auto d_gpair = shard->gpair.data();
-    auto n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
-    const size_t smem_size = sizeof(GradientSumT) * shard->ellpack_matrix.BinCount();
-    const int items_per_thread = 8;
-    const int block_threads = 256;
-    const int grid_size =
-        static_cast<int>(dh::DivRoundUp(n_elements,
-                                        items_per_thread * block_threads));
-    if (grid_size <= 0) {
-      return;
-    }
-    SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
-        shard->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair,
-        segment_begin, n_elements);
-  }
-};
-template <typename GradientSumT>
-struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
-  void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
-    Segment segment = shard->ridx_segments[nidx];
-    auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
-    bst_uint* d_ridx = shard->ridx.Current();
-    GradientPair* d_gpair = shard->gpair.data();
-    size_t const n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
-    auto d_matrix = shard->ellpack_matrix;
-    dh::LaunchN(shard->device_id, n_elements, [=] __device__(size_t idx) {
-      int ridx = d_ridx[(idx / d_matrix.row_stride) + segment.begin];
-      // lookup the index (bin) of histogram.
-      int gidx = d_matrix.gidx_iter[ridx * d_matrix.row_stride + idx % d_matrix.row_stride];
-      if (gidx != d_matrix.null_gidx_value) {
-        AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]);
-      }
-    });
-  }
-};
 template <typename GradientSumT>
 inline void DeviceShard<GradientSumT>::InitCompressedData(
     const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense) {
@@ -1317,9 +1286,7 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
   auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back();
   auto max_smem = dh::MaxSharedMemory(device_id);
   if (histogram_size <= max_smem) {
-    hist_builder.reset(new SharedMemHistBuilder<GradientSumT>);
-  } else {
-    hist_builder.reset(new GlobalMemHistBuilder<GradientSumT>);
+    use_shared_memory_histograms = true;
   }
   // Init histogram
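
The decision is made once per shard: shared-memory histograms are enabled only if a full node histogram (sizeof(GradientSumT) bytes per bin times the total bin count) fits in the device's per-block shared memory; otherwise the kernel falls back to global-memory atomics. A back-of-the-envelope check of that condition (the bin layout and the 48 KB limit are illustrative assumptions, not values from this commit):

#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t bytes_per_bin = 16;    // e.g. a pair of doubles per bin
  const std::size_t max_smem = 48 * 1024;  // assumed shared memory per block
  const std::size_t n_bins = 12 * 256;     // assumed: 12 features x 256 bins
  const std::size_t histogram_size = bytes_per_bin * n_bins;  // 49152 bytes
  std::printf("histogram needs %zu bytes -> %s path\n", histogram_size,
              histogram_size <= max_smem ? "shared-memory" : "global-memory");
  return 0;
}

Under the same assumptions, 16 features at 256 bins each would need 65,536 bytes and would take the global-memory path instead.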


@@ -162,7 +162,7 @@ std::vector<GradientPairPrecise> GetHostHistGpair() {
 }
 template <typename GradientSumT>
-void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
+void TestBuildHist(bool use_shared_memory_histograms) {
   int const kNRows = 16, kNCols = 8;
   TrainParam param;
@@ -170,7 +170,6 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
   param.max_leaves = 0;
   DeviceShard<GradientSumT> shard(0, 0, 0, kNRows, param, kNCols);
   BuildGidx(&shard, kNRows, kNCols);
   xgboost::SimpleLCG gen;
@@ -202,7 +201,8 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
       thrust::device_pointer_cast(shard.ridx.Current()),
       thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size()));
-  builder.Build(&shard, 0);
+  shard.use_shared_memory_histograms = use_shared_memory_histograms;
+  shard.BuildHist(0);
   DeviceHistogram<GradientSumT> d_hist = shard.hist;
   auto node_histogram = d_hist.GetNodeHistogram(0);
@@ -224,17 +224,13 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
 }
 TEST(GpuHist, BuildHistGlobalMem) {
-  GlobalMemHistBuilder<GradientPairPrecise> double_builder;
-  TestBuildHist(double_builder);
-  GlobalMemHistBuilder<GradientPair> float_builder;
-  TestBuildHist(float_builder);
+  TestBuildHist<GradientPairPrecise>(false);
+  TestBuildHist<GradientPair>(false);
 }
 TEST(GpuHist, BuildHistSharedMem) {
-  SharedMemHistBuilder<GradientPairPrecise> double_builder;
-  TestBuildHist(double_builder);
-  SharedMemHistBuilder<GradientPair> float_builder;
-  TestBuildHist(float_builder);
+  TestBuildHist<GradientPairPrecise>(true);
+  TestBuildHist<GradientPair>(true);
 }
 common::HistCutMatrix GetHostCutMatrix () {