Cache GPU histogram kernel configuration. (#10538)

This commit is contained in:
Jiaming Yuan 2024-07-04 15:38:59 +08:00 committed by GitHub
parent cd1d108c7d
commit 620b2b155a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 185 additions and 118 deletions

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2020 by XGBoost Contributors * Copyright 2020-2024, XGBoost Contributors
*/ */
#include <xgboost/base.h> #include <xgboost/base.h>
@ -8,12 +8,9 @@
#include "feature_groups.cuh" #include "feature_groups.cuh"
#include "../../common/device_helpers.cuh"
#include "../../common/hist_util.h" #include "../../common/hist_util.h"
namespace xgboost { namespace xgboost::tree {
namespace tree {
FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense, FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
size_t shm_size, size_t bin_size) { size_t shm_size, size_t bin_size) {
// Only use a single feature group for sparse matrices. // Only use a single feature group for sparse matrices.
@ -59,6 +56,4 @@ void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) {
max_group_bins = cuts.TotalBins(); max_group_bins = cuts.TotalBins();
} }
} // namespace xgboost::tree
} // namespace tree
} // namespace xgboost

View File

@ -5,8 +5,7 @@
#include <thrust/reduce.h> #include <thrust/reduce.h>
#include <algorithm> #include <algorithm>
#include <cstdint> // uint32_t #include <cstdint> // uint32_t, int32_t
#include <limits>
#include "../../collective/aggregator.h" #include "../../collective/aggregator.h"
#include "../../common/deterministic.cuh" #include "../../common/deterministic.cuh"
@ -128,7 +127,7 @@ XGBOOST_DEV_INLINE void AtomicAddGpairGlobal(xgboost::GradientPairInt64* dest,
} }
template <int kBlockThreads, int kItemsPerThread, template <int kBlockThreads, int kItemsPerThread,
int kItemsPerTile = kBlockThreads* kItemsPerThread> int kItemsPerTile = kBlockThreads * kItemsPerThread>
class HistogramAgent { class HistogramAgent {
GradientPairInt64* smem_arr_; GradientPairInt64* smem_arr_;
GradientPairInt64* d_node_hist_; GradientPairInt64* d_node_hist_;
@ -244,53 +243,82 @@ __global__ void __launch_bounds__(kBlockThreads)
extern __shared__ char smem[]; extern __shared__ char smem[];
const FeatureGroup group = feature_groups[blockIdx.y]; const FeatureGroup group = feature_groups[blockIdx.y];
auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem); auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>( auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>(smem_arr, d_node_hist, group, matrix,
smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair); d_ridx, rounding, d_gpair);
if (use_shared_memory_histograms) { if (use_shared_memory_histograms) {
agent.BuildHistogramWithShared(); agent.BuildHistogramWithShared();
} else { } else {
agent.BuildHistogramWithGlobal(); agent.BuildHistogramWithGlobal();
} }
} }
namespace {
constexpr std::int32_t kBlockThreads = 1024;
constexpr std::int32_t kItemsPerThread = 8;
constexpr std::int32_t ItemsPerTile() { return kBlockThreads * kItemsPerThread; }
} // namespace
void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix, // Use auto deduction guide to workaround compiler error.
FeatureGroupsAccessor const& feature_groups, template <auto Global = SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>,
common::Span<GradientPair const> gpair, auto Shared = SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>>
common::Span<const uint32_t> d_ridx, struct HistogramKernel {
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding, decltype(Global) global_kernel{SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>};
decltype(Shared) shared_kernel{SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>};
bool shared{false};
std::uint32_t grid_size{0};
std::size_t smem_size{0};
HistogramKernel(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
bool force_global_memory) { bool force_global_memory) {
// decide whether to use shared memory // Decide whether to use shared memory
int device = 0; // Opt into maximum shared memory for the kernel if necessary
dh::safe_cuda(cudaGetDevice(&device)); std::size_t max_shared_memory = dh::MaxSharedMemoryOptin(ctx->Ordinal());
// opt into maximum shared memory for the kernel if necessary
size_t max_shared_memory = dh::MaxSharedMemoryOptin(device);
size_t smem_size = this->smem_size = sizeof(GradientPairInt64) * feature_groups.max_group_bins;
sizeof(GradientPairInt64) * feature_groups.max_group_bins; this->shared = !force_global_memory && smem_size <= max_shared_memory;
bool shared = !force_global_memory && smem_size <= max_shared_memory; this->smem_size = this->shared ? this->smem_size : 0;
smem_size = shared ? smem_size : 0;
constexpr int kBlockThreads = 1024; auto init = [&](auto& kernel) {
constexpr int kItemsPerThread = 8; if (this->shared) {
constexpr int kItemsPerTile = kBlockThreads * kItemsPerThread;
auto runit = [&, kMinItemsPerBlock = kItemsPerTile](auto kernel) {
if (shared) {
dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_memory)); max_shared_memory));
} }
// determine the launch configuration // determine the launch configuration
int num_groups = feature_groups.NumGroups(); std::int32_t num_groups = feature_groups.NumGroups();
int n_mps = 0; std::int32_t n_mps = 0;
dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device)); dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, ctx->Ordinal()));
int n_blocks_per_mp = 0;
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
kBlockThreads, smem_size));
// This gives the number of blocks to keep the device occupied
// Use this as the maximum number of blocks
unsigned grid_size = n_blocks_per_mp * n_mps;
std::int32_t n_blocks_per_mp = 0;
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
kBlockThreads, this->smem_size));
// This gives the number of blocks to keep the device occupied Use this as the
// maximum number of blocks
this->grid_size = n_blocks_per_mp * n_mps;
};
init(this->global_kernel);
init(this->shared_kernel);
}
};
class DeviceHistogramBuilderImpl {
std::unique_ptr<HistogramKernel<>> kernel_{nullptr};
bool force_global_memory_{false};
public:
void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
bool force_global_memory) {
this->kernel_ = std::make_unique<HistogramKernel<>>(ctx, feature_groups, force_global_memory);
this->force_global_memory_ = force_global_memory;
}
void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const std::uint32_t> d_ridx,
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding) {
CHECK(kernel_);
// Otherwise launch blocks such that each block has a minimum amount of work to do // Otherwise launch blocks such that each block has a minimum amount of work to do
// There are fixed costs to launching each block, e.g. zeroing shared memory // There are fixed costs to launching each block, e.g. zeroing shared memory
// The below amount of minimum work was found by experimentation // The below amount of minimum work was found by experimentation
@ -300,20 +328,41 @@ void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const&
// Allocate number of blocks such that each block has about kMinItemsPerBlock work // Allocate number of blocks such that each block has about kMinItemsPerBlock work
// Up to a maximum where the device is saturated // Up to a maximum where the device is saturated
grid_size = std::min(grid_size, static_cast<std::uint32_t>( auto constexpr kMinItemsPerBlock = ItemsPerTile();
common::DivRoundUp(items_per_group, kMinItemsPerBlock))); auto grid_size = std::min(kernel_->grid_size, static_cast<std::uint32_t>(common::DivRoundUp(
items_per_group, kMinItemsPerBlock)));
dh::LaunchKernel {dim3(grid_size, num_groups), static_cast<uint32_t>(kBlockThreads), smem_size, if (this->force_global_memory_ || !this->kernel_->shared) {
ctx->Stream()} (kernel, matrix, feature_groups, d_ridx, histogram.data(), dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT
gpair.data(), rounding); static_cast<uint32_t>(kBlockThreads), kernel_->smem_size,
}; ctx->Stream()}(kernel_->global_kernel, matrix, feature_groups, d_ridx,
histogram.data(), gpair.data(), rounding);
if (shared) {
runit(SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>);
} else { } else {
runit(SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>); dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT
static_cast<uint32_t>(kBlockThreads), kernel_->smem_size,
ctx->Stream()}(kernel_->shared_kernel, matrix, feature_groups, d_ridx,
histogram.data(), gpair.data(), rounding);
} }
}
};
dh::safe_cuda(cudaGetLastError()); DeviceHistogramBuilder::DeviceHistogramBuilder()
: p_impl_{std::make_unique<DeviceHistogramBuilderImpl>()} {}
DeviceHistogramBuilder::~DeviceHistogramBuilder() = default;
void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
bool force_global_memory) {
this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
}
void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const std::uint32_t> ridx,
common::Span<GradientPairInt64> histogram,
GradientQuantiser rounding) {
this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
} }
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -1,17 +1,18 @@
/*! /**
* Copyright 2020-2021 by XGBoost Contributors * Copyright 2020-2024, XGBoost Contributors
*/ */
#ifndef HISTOGRAM_CUH_ #ifndef HISTOGRAM_CUH_
#define HISTOGRAM_CUH_ #define HISTOGRAM_CUH_
#include <thrust/transform.h> #include <memory> // for unique_ptr
#include "../../common/cuda_context.cuh" #include "../../common/cuda_context.cuh" // for CUDAContext
#include "../../data/ellpack_page.cuh" #include "../../data/ellpack_page.cuh" // for EllpackDeviceAccessor
#include "feature_groups.cuh" #include "feature_groups.cuh" // for FeatureGroupsAccessor
#include "xgboost/base.h" // for GradientPair, GradientPairInt64
namespace xgboost { #include "xgboost/context.h" // for Context
namespace tree { #include "xgboost/span.h" // for Span
namespace xgboost::tree {
/** /**
* \brief An atomicAdd designed for gradient pair with better performance. For general * \brief An atomicAdd designed for gradient pair with better performance. For general
* int64_t atomicAdd, one can simply cast it to unsigned long long. Exposed for testing. * int64_t atomicAdd, one can simply cast it to unsigned long long. Exposed for testing.
@ -32,7 +33,7 @@ XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t* dst, int64_t src) {
} }
class GradientQuantiser { class GradientQuantiser {
private: private:
/* Convert gradient to fixed point representation. */ /* Convert gradient to fixed point representation. */
GradientPairPrecise to_fixed_point_; GradientPairPrecise to_fixed_point_;
/* Convert fixed point representation back to floating point. */ /* Convert fixed point representation back to floating point. */
@ -59,13 +60,23 @@ private:
} }
}; };
void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix, class DeviceHistogramBuilderImpl;
class DeviceHistogramBuilder {
std::unique_ptr<DeviceHistogramBuilderImpl> p_impl_;
public:
DeviceHistogramBuilder();
~DeviceHistogramBuilder();
void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
bool force_global_memory);
void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups, FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair, common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx, common::Span<const std::uint32_t> ridx,
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding, common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
bool force_global_memory = false); };
} // namespace tree } // namespace xgboost::tree
} // namespace xgboost
#endif // HISTOGRAM_CUH_ #endif // HISTOGRAM_CUH_

View File

@ -162,6 +162,8 @@ struct GPUHistMakerDevice {
std::shared_ptr<common::ColumnSampler> column_sampler_; std::shared_ptr<common::ColumnSampler> column_sampler_;
MetaInfo const& info_; MetaInfo const& info_;
DeviceHistogramBuilder histogram_;
public: public:
EllpackPageImpl const* page{nullptr}; EllpackPageImpl const* page{nullptr};
common::Span<FeatureType const> feature_types; common::Span<FeatureType const> feature_types;
@ -256,6 +258,8 @@ struct GPUHistMakerDevice {
hist.Reset(); hist.Reset();
this->InitFeatureGroupsOnce(); this->InitFeatureGroupsOnce();
this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false);
} }
GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) { GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
@ -340,7 +344,7 @@ struct GPUHistMakerDevice {
void BuildHist(int nidx) { void BuildHist(int nidx) {
auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = row_partitioner->GetRows(nidx); auto d_ridx = row_partitioner->GetRows(nidx);
BuildGradientHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()), this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx, feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
d_node_hist, *quantiser); d_node_hist, *quantiser);
} }

View File

@ -1,11 +1,10 @@
/** /**
* Copyright 2020-2023, XGBoost Contributors * Copyright 2020-2024, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
#include "../../../../src/common/categorical.h"
#include "../../../../src/tree/gpu_hist/histogram.cuh" #include "../../../../src/tree/gpu_hist/histogram.cuh"
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh" #include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
#include "../../../../src/tree/param.h" // TrainParam #include "../../../../src/tree/param.h" // TrainParam
@ -13,7 +12,7 @@
#include "../../helpers.h" #include "../../helpers.h"
namespace xgboost::tree { namespace xgboost::tree {
void TestDeterministicHistogram(bool is_dense, int shm_size) { void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) {
Context ctx = MakeCUDACtx(0); Context ctx = MakeCUDACtx(0);
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16; size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
float constexpr kLower = -1e-2, kUpper = 1e2; float constexpr kLower = -1e-2, kUpper = 1e2;
@ -25,35 +24,37 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) { for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
auto* page = batch.Impl(); auto* page = batch.Impl();
tree::RowPartitioner row_partitioner(FstCU(), kRows); tree::RowPartitioner row_partitioner(ctx.Device(), kRows);
auto ridx = row_partitioner.GetRows(0); auto ridx = row_partitioner.GetRows(0);
int num_bins = kBins * kCols; bst_bin_t num_bins = kBins * kCols;
dh::device_vector<GradientPairInt64> histogram(num_bins); dh::device_vector<GradientPairInt64> histogram(num_bins);
auto d_histogram = dh::ToSpan(histogram); auto d_histogram = dh::ToSpan(histogram);
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper); auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
gpair.SetDevice(FstCU()); gpair.SetDevice(ctx.Device());
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(GradientPairInt64));
sizeof(GradientPairInt64));
auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo()); auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()), DeviceHistogramBuilder builder;
feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx, builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_histogram, quantiser); d_histogram, quantiser);
std::vector<GradientPairInt64> histogram_h(num_bins); std::vector<GradientPairInt64> histogram_h(num_bins);
dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(), dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
num_bins * sizeof(GradientPairInt64), num_bins * sizeof(GradientPairInt64), cudaMemcpyDeviceToHost));
cudaMemcpyDeviceToHost));
for (size_t i = 0; i < kRounds; ++i) { for (std::size_t i = 0; i < kRounds; ++i) {
dh::device_vector<GradientPairInt64> new_histogram(num_bins); dh::device_vector<GradientPairInt64> new_histogram(num_bins);
auto d_new_histogram = dh::ToSpan(new_histogram); auto d_new_histogram = dh::ToSpan(new_histogram);
auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo()); auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()), DeviceHistogramBuilder builder;
feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx, builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_new_histogram, quantiser); d_new_histogram, quantiser);
std::vector<GradientPairInt64> new_histogram_h(num_bins); std::vector<GradientPairInt64> new_histogram_h(num_bins);
@ -68,14 +69,16 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
{ {
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper); auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
gpair.SetDevice(FstCU()); gpair.SetDevice(ctx.Device());
// Use a single feature group to compute the baseline. // Use a single feature group to compute the baseline.
FeatureGroups single_group(page->Cuts()); FeatureGroups single_group(page->Cuts());
dh::device_vector<GradientPairInt64> baseline(num_bins); dh::device_vector<GradientPairInt64> baseline(num_bins);
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()), DeviceHistogramBuilder builder;
single_group.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx, builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(baseline), quantiser); dh::ToSpan(baseline), quantiser);
std::vector<GradientPairInt64> baseline_h(num_bins); std::vector<GradientPairInt64> baseline_h(num_bins);
@ -96,7 +99,9 @@ TEST(Histogram, GPUDeterministic) {
std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024}; std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
for (bool is_dense : is_dense_array) { for (bool is_dense : is_dense_array) {
for (int shm_size : shm_sizes) { for (int shm_size : shm_sizes) {
TestDeterministicHistogram(is_dense, shm_size); for (bool force_global : {true, false}) {
TestDeterministicHistogram(is_dense, shm_size, force_global);
}
} }
} }
} }
@ -136,7 +141,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
for (auto const &batch : cat_m->GetBatches<EllpackPage>(&ctx, batch_param)) { for (auto const &batch : cat_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
auto* page = batch.Impl(); auto* page = batch.Impl();
FeatureGroups single_group(page->Cuts()); FeatureGroups single_group(page->Cuts());
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(cat_hist), quantiser); dh::ToSpan(cat_hist), quantiser);
} }
@ -150,7 +157,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
for (auto const &batch : encode_m->GetBatches<EllpackPage>(&ctx, batch_param)) { for (auto const &batch : encode_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
auto* page = batch.Impl(); auto* page = batch.Impl();
FeatureGroups single_group(page->Cuts()); FeatureGroups single_group(page->Cuts());
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(encode_hist), quantiser); dh::ToSpan(encode_hist), quantiser);
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2017-2023 by XGBoost contributors * Copyright 2017-2024, XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
@ -22,12 +22,8 @@
#include "xgboost/context.h" #include "xgboost/context.h"
#include "xgboost/json.h" #include "xgboost/json.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../plugin/federated/test_worker.h" // for TestFederatedGlobal
#endif // defined(XGBOOST_USE_FEDERATED)
namespace xgboost::tree { namespace xgboost::tree {
TEST(GpuHist, DeviceHistogram) { TEST(GpuHist, DeviceHistogramStorage) {
// Ensures that node allocates correctly after reaching `kStopGrowingSize`. // Ensures that node allocates correctly after reaching `kStopGrowingSize`.
dh::safe_cuda(cudaSetDevice(0)); dh::safe_cuda(cudaSetDevice(0));
constexpr size_t kNBins = 128; constexpr size_t kNBins = 128;
@ -102,17 +98,17 @@ void TestBuildHist(bool use_shared_memory_histograms) {
xgboost::SimpleLCG gen; xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f); xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
HostDeviceVector<GradientPair> gpair(kNRows); HostDeviceVector<GradientPair> gpair(kNRows);
for (auto &gp : gpair.HostVector()) { for (auto& gp : gpair.HostVector()) {
bst_float grad = dist(&gen); float grad = dist(&gen);
bst_float hess = dist(&gen); float hess = dist(&gen);
gp = GradientPair(grad, hess); gp = GradientPair{grad, hess};
} }
gpair.SetDevice(DeviceOrd::CUDA(0)); gpair.SetDevice(ctx.Device());
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector()); thrust::host_vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.HostVector());
maker.row_partitioner = std::make_unique<RowPartitioner>(FstCU(), kNRows); maker.row_partitioner = std::make_unique<RowPartitioner>(ctx.Device(), kNRows);
maker.hist.Init(FstCU(), page->Cuts().TotalBins()); maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
maker.hist.AllocateHistograms({0}); maker.hist.AllocateHistograms({0});
maker.gpair = gpair.DeviceSpan(); maker.gpair = gpair.DeviceSpan();
@ -121,10 +117,13 @@ void TestBuildHist(bool use_shared_memory_histograms) {
maker.InitFeatureGroupsOnce(); maker.InitFeatureGroupsOnce();
BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(DeviceOrd::CUDA(0)), DeviceHistogramBuilder builder;
maker.feature_groups->DeviceAccessor(DeviceOrd::CUDA(0)), gpair.DeviceSpan(), builder.Reset(&ctx, maker.feature_groups->DeviceAccessor(ctx.Device()),
!use_shared_memory_histograms);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
maker.feature_groups->DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0), maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0),
*maker.quantiser, !use_shared_memory_histograms); *maker.quantiser);
DeviceHistogramStorage<>& d_hist = maker.hist; DeviceHistogramStorage<>& d_hist = maker.hist;