Cache GPU histogram kernel configuration. (#10538)

This commit is contained in:
Jiaming Yuan
2024-07-04 15:38:59 +08:00
committed by GitHub
parent cd1d108c7d
commit 620b2b155a
6 changed files with 185 additions and 118 deletions

View File

@@ -1,11 +1,10 @@
/**
* Copyright 2020-2023, XGBoost Contributors
* Copyright 2020-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <vector>
#include "../../../../src/common/categorical.h"
#include "../../../../src/tree/gpu_hist/histogram.cuh"
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
#include "../../../../src/tree/param.h" // TrainParam
@@ -13,7 +12,7 @@
#include "../../helpers.h"
namespace xgboost::tree {
void TestDeterministicHistogram(bool is_dense, int shm_size) {
void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) {
Context ctx = MakeCUDACtx(0);
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
float constexpr kLower = -1e-2, kUpper = 1e2;
@@ -25,35 +24,37 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
auto* page = batch.Impl();
tree::RowPartitioner row_partitioner(FstCU(), kRows);
tree::RowPartitioner row_partitioner(ctx.Device(), kRows);
auto ridx = row_partitioner.GetRows(0);
int num_bins = kBins * kCols;
bst_bin_t num_bins = kBins * kCols;
dh::device_vector<GradientPairInt64> histogram(num_bins);
auto d_histogram = dh::ToSpan(histogram);
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
gpair.SetDevice(FstCU());
gpair.SetDevice(ctx.Device());
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
sizeof(GradientPairInt64));
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(GradientPairInt64));
auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
DeviceHistogramBuilder builder;
builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_histogram, quantiser);
std::vector<GradientPairInt64> histogram_h(num_bins);
dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
num_bins * sizeof(GradientPairInt64),
cudaMemcpyDeviceToHost));
num_bins * sizeof(GradientPairInt64), cudaMemcpyDeviceToHost));
for (size_t i = 0; i < kRounds; ++i) {
for (std::size_t i = 0; i < kRounds; ++i) {
dh::device_vector<GradientPairInt64> new_histogram(num_bins);
auto d_new_histogram = dh::ToSpan(new_histogram);
auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
DeviceHistogramBuilder builder;
builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_new_histogram, quantiser);
std::vector<GradientPairInt64> new_histogram_h(num_bins);
@@ -68,14 +69,16 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
{
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
gpair.SetDevice(FstCU());
gpair.SetDevice(ctx.Device());
// Use a single feature group to compute the baseline.
FeatureGroups single_group(page->Cuts());
dh::device_vector<GradientPairInt64> baseline(num_bins);
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
single_group.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(baseline), quantiser);
std::vector<GradientPairInt64> baseline_h(num_bins);
@@ -96,7 +99,9 @@ TEST(Histogram, GPUDeterministic) {
std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
for (bool is_dense : is_dense_array) {
for (int shm_size : shm_sizes) {
TestDeterministicHistogram(is_dense, shm_size);
for (bool force_global : {true, false}) {
TestDeterministicHistogram(is_dense, shm_size, force_global);
}
}
}
}
@@ -136,7 +141,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
for (auto const &batch : cat_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
auto* page = batch.Impl();
FeatureGroups single_group(page->Cuts());
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(cat_hist), quantiser);
}
@@ -150,7 +157,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
for (auto const &batch : encode_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
auto* page = batch.Impl();
FeatureGroups single_group(page->Cuts());
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(encode_hist), quantiser);
}

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 by XGBoost contributors
* Copyright 2017-2024, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <thrust/device_vector.h>
@@ -22,12 +22,8 @@
#include "xgboost/context.h"
#include "xgboost/json.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../plugin/federated/test_worker.h" // for TestFederatedGlobal
#endif // defined(XGBOOST_USE_FEDERATED)
namespace xgboost::tree {
TEST(GpuHist, DeviceHistogram) {
TEST(GpuHist, DeviceHistogramStorage) {
// Ensures that node allocates correctly after reaching `kStopGrowingSize`.
dh::safe_cuda(cudaSetDevice(0));
constexpr size_t kNBins = 128;
@@ -102,17 +98,17 @@ void TestBuildHist(bool use_shared_memory_histograms) {
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
HostDeviceVector<GradientPair> gpair(kNRows);
for (auto &gp : gpair.HostVector()) {
bst_float grad = dist(&gen);
bst_float hess = dist(&gen);
gp = GradientPair(grad, hess);
for (auto& gp : gpair.HostVector()) {
float grad = dist(&gen);
float hess = dist(&gen);
gp = GradientPair{grad, hess};
}
gpair.SetDevice(DeviceOrd::CUDA(0));
gpair.SetDevice(ctx.Device());
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());
maker.row_partitioner = std::make_unique<RowPartitioner>(FstCU(), kNRows);
thrust::host_vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.HostVector());
maker.row_partitioner = std::make_unique<RowPartitioner>(ctx.Device(), kNRows);
maker.hist.Init(FstCU(), page->Cuts().TotalBins());
maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
maker.hist.AllocateHistograms({0});
maker.gpair = gpair.DeviceSpan();
@@ -121,10 +117,13 @@ void TestBuildHist(bool use_shared_memory_histograms) {
maker.InitFeatureGroupsOnce();
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(DeviceOrd::CUDA(0)),
maker.feature_groups->DeviceAccessor(DeviceOrd::CUDA(0)), gpair.DeviceSpan(),
DeviceHistogramBuilder builder;
builder.Reset(&ctx, maker.feature_groups->DeviceAccessor(ctx.Device()),
!use_shared_memory_histograms);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
maker.feature_groups->DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0),
*maker.quantiser, !use_shared_memory_histograms);
*maker.quantiser);
DeviceHistogramStorage<>& d_hist = maker.hist;