Cache GPU histogram kernel configuration. (#10538)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2020-2024, XGBoost Contributors
|
||||
*/
|
||||
|
||||
#include <xgboost/base.h>
|
||||
@@ -8,12 +8,9 @@
|
||||
|
||||
#include "feature_groups.cuh"
|
||||
|
||||
#include "../../common/device_helpers.cuh"
|
||||
#include "../../common/hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
namespace xgboost::tree {
|
||||
FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
|
||||
size_t shm_size, size_t bin_size) {
|
||||
// Only use a single feature group for sparse matrices.
|
||||
@@ -59,6 +56,4 @@ void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) {
|
||||
|
||||
max_group_bins = cuts.TotalBins();
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -5,8 +5,7 @@
|
||||
#include <thrust/reduce.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint> // uint32_t
|
||||
#include <limits>
|
||||
#include <cstdint> // uint32_t, int32_t
|
||||
|
||||
#include "../../collective/aggregator.h"
|
||||
#include "../../common/deterministic.cuh"
|
||||
@@ -128,7 +127,7 @@ XGBOOST_DEV_INLINE void AtomicAddGpairGlobal(xgboost::GradientPairInt64* dest,
|
||||
}
|
||||
|
||||
template <int kBlockThreads, int kItemsPerThread,
|
||||
int kItemsPerTile = kBlockThreads* kItemsPerThread>
|
||||
int kItemsPerTile = kBlockThreads * kItemsPerThread>
|
||||
class HistogramAgent {
|
||||
GradientPairInt64* smem_arr_;
|
||||
GradientPairInt64* d_node_hist_;
|
||||
@@ -244,53 +243,82 @@ __global__ void __launch_bounds__(kBlockThreads)
|
||||
extern __shared__ char smem[];
|
||||
const FeatureGroup group = feature_groups[blockIdx.y];
|
||||
auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
|
||||
auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>(
|
||||
smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair);
|
||||
auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>(smem_arr, d_node_hist, group, matrix,
|
||||
d_ridx, rounding, d_gpair);
|
||||
if (use_shared_memory_histograms) {
|
||||
agent.BuildHistogramWithShared();
|
||||
} else {
|
||||
agent.BuildHistogramWithGlobal();
|
||||
}
|
||||
}
|
||||
namespace {
|
||||
constexpr std::int32_t kBlockThreads = 1024;
|
||||
constexpr std::int32_t kItemsPerThread = 8;
|
||||
constexpr std::int32_t ItemsPerTile() { return kBlockThreads * kItemsPerThread; }
|
||||
} // namespace
|
||||
|
||||
void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> d_ridx,
|
||||
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
|
||||
bool force_global_memory) {
|
||||
// decide whether to use shared memory
|
||||
int device = 0;
|
||||
dh::safe_cuda(cudaGetDevice(&device));
|
||||
// opt into maximum shared memory for the kernel if necessary
|
||||
size_t max_shared_memory = dh::MaxSharedMemoryOptin(device);
|
||||
// Use auto deduction guide to workaround compiler error.
|
||||
template <auto Global = SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>,
|
||||
auto Shared = SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>>
|
||||
struct HistogramKernel {
|
||||
decltype(Global) global_kernel{SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>};
|
||||
decltype(Shared) shared_kernel{SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>};
|
||||
bool shared{false};
|
||||
std::uint32_t grid_size{0};
|
||||
std::size_t smem_size{0};
|
||||
|
||||
size_t smem_size =
|
||||
sizeof(GradientPairInt64) * feature_groups.max_group_bins;
|
||||
bool shared = !force_global_memory && smem_size <= max_shared_memory;
|
||||
smem_size = shared ? smem_size : 0;
|
||||
HistogramKernel(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
|
||||
bool force_global_memory) {
|
||||
// Decide whether to use shared memory
|
||||
// Opt into maximum shared memory for the kernel if necessary
|
||||
std::size_t max_shared_memory = dh::MaxSharedMemoryOptin(ctx->Ordinal());
|
||||
|
||||
constexpr int kBlockThreads = 1024;
|
||||
constexpr int kItemsPerThread = 8;
|
||||
constexpr int kItemsPerTile = kBlockThreads * kItemsPerThread;
|
||||
this->smem_size = sizeof(GradientPairInt64) * feature_groups.max_group_bins;
|
||||
this->shared = !force_global_memory && smem_size <= max_shared_memory;
|
||||
this->smem_size = this->shared ? this->smem_size : 0;
|
||||
|
||||
auto runit = [&, kMinItemsPerBlock = kItemsPerTile](auto kernel) {
|
||||
if (shared) {
|
||||
dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
max_shared_memory));
|
||||
}
|
||||
auto init = [&](auto& kernel) {
|
||||
if (this->shared) {
|
||||
dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
max_shared_memory));
|
||||
}
|
||||
|
||||
// determine the launch configuration
|
||||
int num_groups = feature_groups.NumGroups();
|
||||
int n_mps = 0;
|
||||
dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
|
||||
int n_blocks_per_mp = 0;
|
||||
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
|
||||
kBlockThreads, smem_size));
|
||||
// This gives the number of blocks to keep the device occupied
|
||||
// Use this as the maximum number of blocks
|
||||
unsigned grid_size = n_blocks_per_mp * n_mps;
|
||||
// determine the launch configuration
|
||||
std::int32_t num_groups = feature_groups.NumGroups();
|
||||
std::int32_t n_mps = 0;
|
||||
dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, ctx->Ordinal()));
|
||||
|
||||
std::int32_t n_blocks_per_mp = 0;
|
||||
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
|
||||
kBlockThreads, this->smem_size));
|
||||
|
||||
// This gives the number of blocks to keep the device occupied Use this as the
|
||||
// maximum number of blocks
|
||||
this->grid_size = n_blocks_per_mp * n_mps;
|
||||
};
|
||||
|
||||
init(this->global_kernel);
|
||||
init(this->shared_kernel);
|
||||
}
|
||||
};
|
||||
|
||||
class DeviceHistogramBuilderImpl {
|
||||
std::unique_ptr<HistogramKernel<>> kernel_{nullptr};
|
||||
bool force_global_memory_{false};
|
||||
|
||||
public:
|
||||
void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
|
||||
bool force_global_memory) {
|
||||
this->kernel_ = std::make_unique<HistogramKernel<>>(ctx, feature_groups, force_global_memory);
|
||||
this->force_global_memory_ = force_global_memory;
|
||||
}
|
||||
|
||||
void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const std::uint32_t> d_ridx,
|
||||
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding) {
|
||||
CHECK(kernel_);
|
||||
// Otherwise launch blocks such that each block has a minimum amount of work to do
|
||||
// There are fixed costs to launching each block, e.g. zeroing shared memory
|
||||
// The below amount of minimum work was found by experimentation
|
||||
@@ -300,20 +328,41 @@ void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const&
|
||||
|
||||
// Allocate number of blocks such that each block has about kMinItemsPerBlock work
|
||||
// Up to a maximum where the device is saturated
|
||||
grid_size = std::min(grid_size, static_cast<std::uint32_t>(
|
||||
common::DivRoundUp(items_per_group, kMinItemsPerBlock)));
|
||||
auto constexpr kMinItemsPerBlock = ItemsPerTile();
|
||||
auto grid_size = std::min(kernel_->grid_size, static_cast<std::uint32_t>(common::DivRoundUp(
|
||||
items_per_group, kMinItemsPerBlock)));
|
||||
|
||||
dh::LaunchKernel {dim3(grid_size, num_groups), static_cast<uint32_t>(kBlockThreads), smem_size,
|
||||
ctx->Stream()} (kernel, matrix, feature_groups, d_ridx, histogram.data(),
|
||||
gpair.data(), rounding);
|
||||
};
|
||||
|
||||
if (shared) {
|
||||
runit(SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>);
|
||||
} else {
|
||||
runit(SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>);
|
||||
if (this->force_global_memory_ || !this->kernel_->shared) {
|
||||
dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT
|
||||
static_cast<uint32_t>(kBlockThreads), kernel_->smem_size,
|
||||
ctx->Stream()}(kernel_->global_kernel, matrix, feature_groups, d_ridx,
|
||||
histogram.data(), gpair.data(), rounding);
|
||||
} else {
|
||||
dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT
|
||||
static_cast<uint32_t>(kBlockThreads), kernel_->smem_size,
|
||||
ctx->Stream()}(kernel_->shared_kernel, matrix, feature_groups, d_ridx,
|
||||
histogram.data(), gpair.data(), rounding);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
DeviceHistogramBuilder::DeviceHistogramBuilder()
|
||||
: p_impl_{std::make_unique<DeviceHistogramBuilderImpl>()} {}
|
||||
|
||||
DeviceHistogramBuilder::~DeviceHistogramBuilder() = default;
|
||||
|
||||
void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
|
||||
bool force_global_memory) {
|
||||
this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
|
||||
}
|
||||
|
||||
void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
|
||||
EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const std::uint32_t> ridx,
|
||||
common::Span<GradientPairInt64> histogram,
|
||||
GradientQuantiser rounding) {
|
||||
this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
|
||||
}
|
||||
} // namespace xgboost::tree
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
/*!
|
||||
* Copyright 2020-2021 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2020-2024, XGBoost Contributors
|
||||
*/
|
||||
#ifndef HISTOGRAM_CUH_
|
||||
#define HISTOGRAM_CUH_
|
||||
#include <thrust/transform.h>
|
||||
#include <memory> // for unique_ptr
|
||||
|
||||
#include "../../common/cuda_context.cuh"
|
||||
#include "../../data/ellpack_page.cuh"
|
||||
#include "feature_groups.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
#include "../../common/cuda_context.cuh" // for CUDAContext
|
||||
#include "../../data/ellpack_page.cuh" // for EllpackDeviceAccessor
|
||||
#include "feature_groups.cuh" // for FeatureGroupsAccessor
|
||||
#include "xgboost/base.h" // for GradientPair, GradientPairInt64
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::tree {
|
||||
/**
|
||||
* \brief An atomicAdd designed for gradient pair with better performance. For general
|
||||
* int64_t atomicAdd, one can simply cast it to unsigned long long. Exposed for testing.
|
||||
@@ -32,7 +33,7 @@ XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t* dst, int64_t src) {
|
||||
}
|
||||
|
||||
class GradientQuantiser {
|
||||
private:
|
||||
private:
|
||||
/* Convert gradient to fixed point representation. */
|
||||
GradientPairPrecise to_fixed_point_;
|
||||
/* Convert fixed point representation back to floating point. */
|
||||
@@ -59,13 +60,23 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> ridx,
|
||||
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
|
||||
bool force_global_memory = false);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
class DeviceHistogramBuilderImpl;
|
||||
|
||||
class DeviceHistogramBuilder {
|
||||
std::unique_ptr<DeviceHistogramBuilderImpl> p_impl_;
|
||||
|
||||
public:
|
||||
DeviceHistogramBuilder();
|
||||
~DeviceHistogramBuilder();
|
||||
|
||||
void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
|
||||
bool force_global_memory);
|
||||
void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const std::uint32_t> ridx,
|
||||
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
|
||||
};
|
||||
} // namespace xgboost::tree
|
||||
|
||||
#endif // HISTOGRAM_CUH_
|
||||
|
||||
@@ -162,6 +162,8 @@ struct GPUHistMakerDevice {
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||
MetaInfo const& info_;
|
||||
|
||||
DeviceHistogramBuilder histogram_;
|
||||
|
||||
public:
|
||||
EllpackPageImpl const* page{nullptr};
|
||||
common::Span<FeatureType const> feature_types;
|
||||
@@ -256,6 +258,8 @@ struct GPUHistMakerDevice {
|
||||
hist.Reset();
|
||||
|
||||
this->InitFeatureGroupsOnce();
|
||||
|
||||
this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false);
|
||||
}
|
||||
|
||||
GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
|
||||
@@ -340,9 +344,9 @@ struct GPUHistMakerDevice {
|
||||
void BuildHist(int nidx) {
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx);
|
||||
auto d_ridx = row_partitioner->GetRows(nidx);
|
||||
BuildGradientHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
|
||||
feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
|
||||
d_node_hist, *quantiser);
|
||||
this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
|
||||
feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
|
||||
d_node_hist, *quantiser);
|
||||
}
|
||||
|
||||
// Attempt to do subtraction trick
|
||||
|
||||
Reference in New Issue
Block a user