Refactor histogram building code for gpu_hist (#4528)
This commit is contained in:
parent
399fabed49
commit
23a10c8339
@ -515,28 +515,38 @@ __global__ void CompressBinEllpackKernel(
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
__global__ void SharedMemHistKernel(ELLPackMatrix matrix, const bst_uint* d_ridx,
|
||||
__global__ void SharedMemHistKernel(ELLPackMatrix matrix,
|
||||
const bst_uint* d_ridx,
|
||||
GradientSumT* d_node_hist,
|
||||
const GradientPair* d_gpair,
|
||||
size_t segment_begin, size_t n_elements) {
|
||||
size_t segment_begin, size_t n_elements,
|
||||
bool use_shared_memory_histograms) {
|
||||
extern __shared__ char smem[];
|
||||
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
|
||||
for (auto i :
|
||||
dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
|
||||
smem_arr[i] = GradientSumT();
|
||||
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
|
||||
if (use_shared_memory_histograms) {
|
||||
dh::BlockFill(smem_arr, matrix.BinCount(), GradientSumT());
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
|
||||
int ridx = d_ridx[idx / matrix.row_stride + segment_begin];
|
||||
int gidx = matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
|
||||
int gidx =
|
||||
matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
|
||||
if (gidx != matrix.null_gidx_value) {
|
||||
AtomicAddGpair(smem_arr + gidx, d_gpair[ridx]);
|
||||
// If we are not using shared memory, accumulate the values directly into
|
||||
// global memory
|
||||
GradientSumT* atomic_add_ptr =
|
||||
use_shared_memory_histograms ? smem_arr : d_node_hist;
|
||||
AtomicAddGpair(atomic_add_ptr + gidx, d_gpair[ridx]);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
for (auto i :
|
||||
dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
|
||||
AtomicAddGpair(d_node_hist + i, smem_arr[i]);
|
||||
|
||||
if (use_shared_memory_histograms) {
|
||||
// Write shared memory back to global memory
|
||||
__syncthreads();
|
||||
for (auto i :
|
||||
dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
|
||||
AtomicAddGpair(d_node_hist + i, smem_arr[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -620,16 +630,6 @@ __forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
struct DeviceShard;
|
||||
|
||||
template <typename GradientSumT>
|
||||
struct GPUHistBuilderBase {
|
||||
public:
|
||||
virtual void Build(DeviceShard<GradientSumT>* shard, int idx) = 0;
|
||||
virtual ~GPUHistBuilderBase() = default;
|
||||
};
|
||||
|
||||
// Manage memory for a single GPU
|
||||
template <typename GradientSumT>
|
||||
struct DeviceShard {
|
||||
@ -679,6 +679,7 @@ struct DeviceShard {
|
||||
|
||||
TrainParam param;
|
||||
bool prediction_cache_initialised;
|
||||
bool use_shared_memory_histograms {false};
|
||||
|
||||
dh::CubMemory temp_memory;
|
||||
dh::PinnedMemory pinned_memory;
|
||||
@ -689,8 +690,6 @@ struct DeviceShard {
|
||||
std::vector<ValueConstraint> node_value_constraints;
|
||||
common::ColumnSampler column_sampler;
|
||||
|
||||
std::unique_ptr<GPUHistBuilderBase<GradientSumT>> hist_builder;
|
||||
|
||||
using ExpandQueue =
|
||||
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
|
||||
std::function<bool(ExpandEntry, ExpandEntry)>>;
|
||||
@ -870,7 +869,27 @@ struct DeviceShard {
|
||||
|
||||
void BuildHist(int nidx) {
|
||||
hist.AllocateHistogram(nidx);
|
||||
hist_builder->Build(this, nidx);
|
||||
auto segment = ridx_segments[nidx];
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx);
|
||||
auto d_ridx = ridx.Current();
|
||||
auto d_gpair = gpair.data();
|
||||
|
||||
auto n_elements = segment.Size() * ellpack_matrix.row_stride;
|
||||
|
||||
const size_t smem_size =
|
||||
use_shared_memory_histograms
|
||||
? sizeof(GradientSumT) * ellpack_matrix.BinCount()
|
||||
: 0;
|
||||
const int items_per_thread = 8;
|
||||
const int block_threads = 256;
|
||||
const int grid_size = static_cast<int>(
|
||||
dh::DivRoundUp(n_elements, items_per_thread * block_threads));
|
||||
if (grid_size <= 0) {
|
||||
return;
|
||||
}
|
||||
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
|
||||
ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, segment.begin,
|
||||
n_elements, use_shared_memory_histograms);
|
||||
}
|
||||
|
||||
void SubtractionTrick(int nidx_parent, int nidx_histogram,
|
||||
@ -1084,7 +1103,6 @@ struct DeviceShard {
|
||||
this->AllReduceHist(subtraction_trick_nidx, reducer);
|
||||
}
|
||||
}
|
||||
|
||||
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
|
||||
RegTree& tree = *p_tree;
|
||||
|
||||
@ -1209,55 +1227,6 @@ struct DeviceShard {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
|
||||
void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
|
||||
auto segment = shard->ridx_segments[nidx];
|
||||
auto segment_begin = segment.begin;
|
||||
auto d_node_hist = shard->hist.GetNodeHistogram(nidx);
|
||||
auto d_ridx = shard->ridx.Current();
|
||||
auto d_gpair = shard->gpair.data();
|
||||
|
||||
auto n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
|
||||
|
||||
const size_t smem_size = sizeof(GradientSumT) * shard->ellpack_matrix.BinCount();
|
||||
const int items_per_thread = 8;
|
||||
const int block_threads = 256;
|
||||
const int grid_size =
|
||||
static_cast<int>(dh::DivRoundUp(n_elements,
|
||||
items_per_thread * block_threads));
|
||||
if (grid_size <= 0) {
|
||||
return;
|
||||
}
|
||||
SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
|
||||
shard->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair,
|
||||
segment_begin, n_elements);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
|
||||
void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
|
||||
Segment segment = shard->ridx_segments[nidx];
|
||||
auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
|
||||
bst_uint* d_ridx = shard->ridx.Current();
|
||||
GradientPair* d_gpair = shard->gpair.data();
|
||||
|
||||
size_t const n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
|
||||
auto d_matrix = shard->ellpack_matrix;
|
||||
|
||||
dh::LaunchN(shard->device_id, n_elements, [=] __device__(size_t idx) {
|
||||
int ridx = d_ridx[(idx / d_matrix.row_stride) + segment.begin];
|
||||
// lookup the index (bin) of histogram.
|
||||
int gidx = d_matrix.gidx_iter[ridx * d_matrix.row_stride + idx % d_matrix.row_stride];
|
||||
|
||||
if (gidx != d_matrix.null_gidx_value) {
|
||||
AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense) {
|
||||
@ -1317,9 +1286,7 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back();
|
||||
auto max_smem = dh::MaxSharedMemory(device_id);
|
||||
if (histogram_size <= max_smem) {
|
||||
hist_builder.reset(new SharedMemHistBuilder<GradientSumT>);
|
||||
} else {
|
||||
hist_builder.reset(new GlobalMemHistBuilder<GradientSumT>);
|
||||
use_shared_memory_histograms = true;
|
||||
}
|
||||
|
||||
// Init histogram
|
||||
|
||||
@ -162,7 +162,7 @@ std::vector<GradientPairPrecise> GetHostHistGpair() {
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
|
||||
void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
int const kNRows = 16, kNCols = 8;
|
||||
|
||||
TrainParam param;
|
||||
@ -170,7 +170,6 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
|
||||
param.max_leaves = 0;
|
||||
|
||||
DeviceShard<GradientSumT> shard(0, 0, 0, kNRows, param, kNCols);
|
||||
|
||||
BuildGidx(&shard, kNRows, kNCols);
|
||||
|
||||
xgboost::SimpleLCG gen;
|
||||
@ -202,7 +201,8 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
|
||||
thrust::device_pointer_cast(shard.ridx.Current()),
|
||||
thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size()));
|
||||
|
||||
builder.Build(&shard, 0);
|
||||
shard.use_shared_memory_histograms = use_shared_memory_histograms;
|
||||
shard.BuildHist(0);
|
||||
DeviceHistogram<GradientSumT> d_hist = shard.hist;
|
||||
|
||||
auto node_histogram = d_hist.GetNodeHistogram(0);
|
||||
@ -224,17 +224,13 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
|
||||
}
|
||||
|
||||
TEST(GpuHist, BuildHistGlobalMem) {
|
||||
GlobalMemHistBuilder<GradientPairPrecise> double_builder;
|
||||
TestBuildHist(double_builder);
|
||||
GlobalMemHistBuilder<GradientPair> float_builder;
|
||||
TestBuildHist(float_builder);
|
||||
TestBuildHist<GradientPairPrecise>(false);
|
||||
TestBuildHist<GradientPair>(false);
|
||||
}
|
||||
|
||||
TEST(GpuHist, BuildHistSharedMem) {
|
||||
SharedMemHistBuilder<GradientPairPrecise> double_builder;
|
||||
TestBuildHist(double_builder);
|
||||
SharedMemHistBuilder<GradientPair> float_builder;
|
||||
TestBuildHist(float_builder);
|
||||
TestBuildHist<GradientPairPrecise>(true);
|
||||
TestBuildHist<GradientPair>(true);
|
||||
}
|
||||
|
||||
common::HistCutMatrix GetHostCutMatrix () {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user