Refactor histogram building code for gpu_hist (#4528)
commit 23a10c8339 (parent 399fabed49)
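Summary: the SharedMemHistBuilder / GlobalMemHistBuilder class hierarchy is merged into a single SharedMemHistKernel that takes a runtime flag, use_shared_memory_histograms. When the flag is set, the kernel zero-fills a block-local histogram in shared memory, accumulates into it with atomics, and flushes it to global memory at the end; otherwise it issues the atomics directly against the global histogram. The flag replaces the virtual-dispatch hist_builder member on DeviceShard.

A minimal standalone sketch of the pattern follows (illustrative only, not the xgboost code; HistKernel and every name in it are invented for this example):

// Minimal sketch: one histogram kernel toggled by a runtime flag, mirroring
// the shape of the refactored SharedMemHistKernel.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void HistKernel(const int* bins, const float* grad, int n,
                           float* d_hist, int n_bins, bool use_shared) {
  extern __shared__ float smem[];
  if (use_shared) {
    // Zero the block-local histogram before accumulating into it.
    for (int i = threadIdx.x; i < n_bins; i += blockDim.x) smem[i] = 0.0f;
    __syncthreads();
  }
  // Accumulate into shared memory when available, otherwise directly into
  // the global histogram: same control flow, one kernel.
  float* out = use_shared ? smem : d_hist;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
       idx += gridDim.x * blockDim.x) {
    atomicAdd(&out[bins[idx]], grad[idx]);
  }
  if (use_shared) {
    // Flush the block-local histogram to global memory.
    __syncthreads();
    for (int i = threadIdx.x; i < n_bins; i += blockDim.x) {
      atomicAdd(&d_hist[i], smem[i]);
    }
  }
}

int main() {
  const int n = 1 << 16, n_bins = 256;
  int* bins;
  float *grad, *hist;
  cudaMallocManaged(&bins, n * sizeof(int));
  cudaMallocManaged(&grad, n * sizeof(float));
  cudaMallocManaged(&hist, n_bins * sizeof(float));
  for (int i = 0; i < n; ++i) { bins[i] = i % n_bins; grad[i] = 1.0f; }
  for (int i = 0; i < n_bins; ++i) hist[i] = 0.0f;
  // As in the refactored BuildHist: the dynamic shared memory size is zero
  // when the shared path is disabled, so one launch covers both cases.
  bool use_shared = true;
  size_t smem_size = use_shared ? n_bins * sizeof(float) : 0;
  HistKernel<<<64, 256, smem_size>>>(bins, grad, n, hist, n_bins, use_shared);
  cudaDeviceSynchronize();
  printf("hist[0] = %.0f\n", hist[0]);  // expect 256 (= n / n_bins)
  return 0;
}

The payoff is that the case where the histogram is too large for shared memory no longer needs a separate kernel or builder class; the same launch simply runs with smem_size == 0.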
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -515,30 +515,40 @@ __global__ void CompressBinEllpackKernel(
 }
 
 template <typename GradientSumT>
-__global__ void SharedMemHistKernel(ELLPackMatrix matrix, const bst_uint* d_ridx,
+__global__ void SharedMemHistKernel(ELLPackMatrix matrix,
+                                    const bst_uint* d_ridx,
                                     GradientSumT* d_node_hist,
                                     const GradientPair* d_gpair,
-                                    size_t segment_begin, size_t n_elements) {
+                                    size_t segment_begin, size_t n_elements,
+                                    bool use_shared_memory_histograms) {
   extern __shared__ char smem[];
   GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem);  // NOLINT
-  for (auto i :
-       dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
-    smem_arr[i] = GradientSumT();
-  }
-  __syncthreads();
+  if (use_shared_memory_histograms) {
+    dh::BlockFill(smem_arr, matrix.BinCount(), GradientSumT());
+    __syncthreads();
+  }
   for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
     int ridx = d_ridx[idx / matrix.row_stride + segment_begin];
-    int gidx = matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
+    int gidx =
+        matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
     if (gidx != matrix.null_gidx_value) {
-      AtomicAddGpair(smem_arr + gidx, d_gpair[ridx]);
+      // If we are not using shared memory, accumulate the values directly into
+      // global memory
+      GradientSumT* atomic_add_ptr =
+          use_shared_memory_histograms ? smem_arr : d_node_hist;
+      AtomicAddGpair(atomic_add_ptr + gidx, d_gpair[ridx]);
     }
   }
-  __syncthreads();
-  for (auto i :
-       dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
-    AtomicAddGpair(d_node_hist + i, smem_arr[i]);
-  }
+
+  if (use_shared_memory_histograms) {
+    // Write shared memory back to global memory
+    __syncthreads();
+    for (auto i :
+         dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
+      AtomicAddGpair(d_node_hist + i, smem_arr[i]);
+    }
+  }
 }
 
 struct Segment {
   size_t begin;
@@ -620,16 +630,6 @@ __forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
 #endif
 }
 
-template <typename GradientSumT>
-struct DeviceShard;
-
-template <typename GradientSumT>
-struct GPUHistBuilderBase {
- public:
-  virtual void Build(DeviceShard<GradientSumT>* shard, int idx) = 0;
-  virtual ~GPUHistBuilderBase() = default;
-};
-
 // Manage memory for a single GPU
 template <typename GradientSumT>
 struct DeviceShard {
@@ -679,6 +679,7 @@ struct DeviceShard {
 
   TrainParam param;
   bool prediction_cache_initialised;
+  bool use_shared_memory_histograms {false};
 
   dh::CubMemory temp_memory;
   dh::PinnedMemory pinned_memory;
@@ -689,8 +690,6 @@ struct DeviceShard {
   std::vector<ValueConstraint> node_value_constraints;
   common::ColumnSampler column_sampler;
 
-  std::unique_ptr<GPUHistBuilderBase<GradientSumT>> hist_builder;
-
   using ExpandQueue =
       std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
                           std::function<bool(ExpandEntry, ExpandEntry)>>;
@@ -870,7 +869,27 @@ struct DeviceShard {
 
   void BuildHist(int nidx) {
     hist.AllocateHistogram(nidx);
-    hist_builder->Build(this, nidx);
+    auto segment = ridx_segments[nidx];
+    auto d_node_hist = hist.GetNodeHistogram(nidx);
+    auto d_ridx = ridx.Current();
+    auto d_gpair = gpair.data();
+
+    auto n_elements = segment.Size() * ellpack_matrix.row_stride;
+
+    const size_t smem_size =
+        use_shared_memory_histograms
+            ? sizeof(GradientSumT) * ellpack_matrix.BinCount()
+            : 0;
+    const int items_per_thread = 8;
+    const int block_threads = 256;
+    const int grid_size = static_cast<int>(
+        dh::DivRoundUp(n_elements, items_per_thread * block_threads));
+    if (grid_size <= 0) {
+      return;
+    }
+    SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
+        ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, segment.begin,
+        n_elements, use_shared_memory_histograms);
   }
 
   void SubtractionTrick(int nidx_parent, int nidx_histogram,
@@ -1084,7 +1103,6 @@ struct DeviceShard {
       this->AllReduceHist(subtraction_trick_nidx, reducer);
     }
   }
-
   void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
     RegTree& tree = *p_tree;
 
@@ -1209,55 +1227,6 @@ struct DeviceShard {
   }
 };
 
-template <typename GradientSumT>
-struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
-  void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
-    auto segment = shard->ridx_segments[nidx];
-    auto segment_begin = segment.begin;
-    auto d_node_hist = shard->hist.GetNodeHistogram(nidx);
-    auto d_ridx = shard->ridx.Current();
-    auto d_gpair = shard->gpair.data();
-
-    auto n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
-
-    const size_t smem_size = sizeof(GradientSumT) * shard->ellpack_matrix.BinCount();
-    const int items_per_thread = 8;
-    const int block_threads = 256;
-    const int grid_size =
-        static_cast<int>(dh::DivRoundUp(n_elements,
-                                        items_per_thread * block_threads));
-    if (grid_size <= 0) {
-      return;
-    }
-    SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
-        shard->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair,
-        segment_begin, n_elements);
-  }
-};
-
-template <typename GradientSumT>
-struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
-  void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
-    Segment segment = shard->ridx_segments[nidx];
-    auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
-    bst_uint* d_ridx = shard->ridx.Current();
-    GradientPair* d_gpair = shard->gpair.data();
-
-    size_t const n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
-    auto d_matrix = shard->ellpack_matrix;
-
-    dh::LaunchN(shard->device_id, n_elements, [=] __device__(size_t idx) {
-      int ridx = d_ridx[(idx / d_matrix.row_stride) + segment.begin];
-      // lookup the index (bin) of histogram.
-      int gidx = d_matrix.gidx_iter[ridx * d_matrix.row_stride + idx % d_matrix.row_stride];
-
-      if (gidx != d_matrix.null_gidx_value) {
-        AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]);
-      }
-    });
-  }
-};
-
 template <typename GradientSumT>
 inline void DeviceShard<GradientSumT>::InitCompressedData(
     const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense) {
@@ -1317,9 +1286,7 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
   auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back();
   auto max_smem = dh::MaxSharedMemory(device_id);
   if (histogram_size <= max_smem) {
-    hist_builder.reset(new SharedMemHistBuilder<GradientSumT>);
-  } else {
-    hist_builder.reset(new GlobalMemHistBuilder<GradientSumT>);
+    use_shared_memory_histograms = true;
   }
 
   // Init histogram
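For reference, the shared-versus-global decision in the hunk above is made once per shard at init time: the shared path is enabled only when the whole node histogram fits in a thread block's shared memory. A hedged host-side sketch of that check (assuming dh::MaxSharedMemory wraps the standard attribute query; the bin count and bytes-per-bin below are invented for illustration):

// Hedged sketch of the host-side shared-memory capacity check.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device = 0;
  int max_smem = 0;
  cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, device);
  // e.g. 256 bins, each a double-precision gradient pair (2 * 8 bytes).
  size_t histogram_size = 256 * 2 * sizeof(double);
  bool use_shared_memory_histograms =
      histogram_size <= static_cast<size_t>(max_smem);
  printf("max smem/block: %d B, histogram: %zu B, shared path: %s\n",
         max_smem, histogram_size,
         use_shared_memory_histograms ? "yes" : "no");
  return 0;
}

With the common 48 KB per-block limit and 16-byte double-precision gradient pairs, the shared path covers histograms up to roughly 3,000 bins; beyond that the kernel falls back to global-memory atomics instead of dispatching to a separate builder class.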
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -162,7 +162,7 @@ std::vector<GradientPairPrecise> GetHostHistGpair() {
 }
 
 template <typename GradientSumT>
-void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
+void TestBuildHist(bool use_shared_memory_histograms) {
   int const kNRows = 16, kNCols = 8;
 
   TrainParam param;
@@ -170,7 +170,6 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
   param.max_leaves = 0;
 
   DeviceShard<GradientSumT> shard(0, 0, 0, kNRows, param, kNCols);
-
   BuildGidx(&shard, kNRows, kNCols);
 
   xgboost::SimpleLCG gen;
@@ -202,7 +201,8 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
       thrust::device_pointer_cast(shard.ridx.Current()),
       thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size()));
 
-  builder.Build(&shard, 0);
+  shard.use_shared_memory_histograms = use_shared_memory_histograms;
+  shard.BuildHist(0);
   DeviceHistogram<GradientSumT> d_hist = shard.hist;
 
   auto node_histogram = d_hist.GetNodeHistogram(0);
@@ -224,17 +224,13 @@ void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
 }
 
 TEST(GpuHist, BuildHistGlobalMem) {
-  GlobalMemHistBuilder<GradientPairPrecise> double_builder;
-  TestBuildHist(double_builder);
-  GlobalMemHistBuilder<GradientPair> float_builder;
-  TestBuildHist(float_builder);
+  TestBuildHist<GradientPairPrecise>(false);
+  TestBuildHist<GradientPair>(false);
 }
 
 TEST(GpuHist, BuildHistSharedMem) {
-  SharedMemHistBuilder<GradientPairPrecise> double_builder;
-  TestBuildHist(double_builder);
-  SharedMemHistBuilder<GradientPair> float_builder;
-  TestBuildHist(float_builder);
+  TestBuildHist<GradientPairPrecise>(true);
+  TestBuildHist<GradientPair>(true);
 }
 
 common::HistCutMatrix GetHostCutMatrix () {