Cache GPU histogram kernel configuration. (#10538)
Parent: cd1d108c7d
Commit: 620b2b155a
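The commit replaces the free function BuildGradientHistogram with a DeviceHistogramBuilder whose Reset() computes the GPU histogram kernel's launch configuration once (shared-memory opt-in, occupancy-derived grid size) and whose BuildHistogram() reuses it for every node. The sketch below illustrates that caching idea in isolation; the kernel, struct, and block size are illustrative stand-ins, not the XGBoost implementation.

#include <cuda_runtime.h>
#include <cstddef>

// Illustrative kernel standing in for the real histogram kernel.
__global__ void ToyHistKernel(float const* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] += in[i]; }
}

// Query the launch configuration once and cache it, instead of repeating the
// occupancy query on every histogram build.
struct CachedLaunchConfig {
  int block_size{256};
  int grid_size{0};
  std::size_t smem_size{0};

  explicit CachedLaunchConfig(std::size_t smem_bytes) : smem_size{smem_bytes} {
    int device = 0, n_mps = 0, blocks_per_mp = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device);
    // Relatively expensive query: do it once and keep the result.
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_mp, ToyHistKernel, block_size,
                                                  smem_size);
    grid_size = blocks_per_mp * n_mps;  // enough blocks to keep the device occupied
  }

  void Launch(float const* in, float* out, int n, cudaStream_t stream) const {
    ToyHistKernel<<<grid_size, block_size, smem_size, stream>>>(in, out, n);
  }
};

Constructing the configuration object once and calling Launch() per node mirrors the Reset()/BuildHistogram() split introduced by this commit.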
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2020 by XGBoost Contributors
+/**
+ * Copyright 2020-2024, XGBoost Contributors
  */
 
 #include <xgboost/base.h>
@@ -8,12 +8,9 @@
 
 #include "feature_groups.cuh"
 
-#include "../../common/device_helpers.cuh"
 #include "../../common/hist_util.h"
 
-namespace xgboost {
-namespace tree {
-
+namespace xgboost::tree {
 FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
                              size_t shm_size, size_t bin_size) {
   // Only use a single feature group for sparse matrices.
@@ -59,6 +56,4 @@ void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) {
 
   max_group_bins = cuts.TotalBins();
 }
-
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree
@@ -5,8 +5,7 @@
 #include <thrust/reduce.h>
 
 #include <algorithm>
-#include <cstdint>  // uint32_t
-#include <limits>
+#include <cstdint>  // uint32_t, int32_t
 
 #include "../../collective/aggregator.h"
 #include "../../common/deterministic.cuh"
@@ -128,7 +127,7 @@ XGBOOST_DEV_INLINE void AtomicAddGpairGlobal(xgboost::GradientPairInt64* dest,
 }
 
 template <int kBlockThreads, int kItemsPerThread,
-          int kItemsPerTile = kBlockThreads* kItemsPerThread>
+          int kItemsPerTile = kBlockThreads * kItemsPerThread>
 class HistogramAgent {
   GradientPairInt64* smem_arr_;
   GradientPairInt64* d_node_hist_;
@@ -244,53 +243,82 @@ __global__ void __launch_bounds__(kBlockThreads)
   extern __shared__ char smem[];
   const FeatureGroup group = feature_groups[blockIdx.y];
   auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
-  auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>(
-      smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair);
+  auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>(smem_arr, d_node_hist, group, matrix,
+                                                              d_ridx, rounding, d_gpair);
   if (use_shared_memory_histograms) {
     agent.BuildHistogramWithShared();
   } else {
     agent.BuildHistogramWithGlobal();
   }
 }
+namespace {
+constexpr std::int32_t kBlockThreads = 1024;
+constexpr std::int32_t kItemsPerThread = 8;
+constexpr std::int32_t ItemsPerTile() { return kBlockThreads * kItemsPerThread; }
+}  // namespace
 
-void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
-                            FeatureGroupsAccessor const& feature_groups,
-                            common::Span<GradientPair const> gpair,
-                            common::Span<const uint32_t> d_ridx,
-                            common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
+// Use auto deduction guide to workaround compiler error.
+template <auto Global = SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>,
+          auto Shared = SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>>
+struct HistogramKernel {
+  decltype(Global) global_kernel{SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>};
+  decltype(Shared) shared_kernel{SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>};
+  bool shared{false};
+  std::uint32_t grid_size{0};
+  std::size_t smem_size{0};
+
+  HistogramKernel(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
                   bool force_global_memory) {
-  // decide whether to use shared memory
-  int device = 0;
-  dh::safe_cuda(cudaGetDevice(&device));
-  // opt into maximum shared memory for the kernel if necessary
-  size_t max_shared_memory = dh::MaxSharedMemoryOptin(device);
+    // Decide whether to use shared memory
+    // Opt into maximum shared memory for the kernel if necessary
+    std::size_t max_shared_memory = dh::MaxSharedMemoryOptin(ctx->Ordinal());
 
-  size_t smem_size =
-      sizeof(GradientPairInt64) * feature_groups.max_group_bins;
-  bool shared = !force_global_memory && smem_size <= max_shared_memory;
-  smem_size = shared ? smem_size : 0;
+    this->smem_size = sizeof(GradientPairInt64) * feature_groups.max_group_bins;
+    this->shared = !force_global_memory && smem_size <= max_shared_memory;
+    this->smem_size = this->shared ? this->smem_size : 0;
 
-  constexpr int kBlockThreads = 1024;
-  constexpr int kItemsPerThread = 8;
-  constexpr int kItemsPerTile = kBlockThreads * kItemsPerThread;
-
-  auto runit = [&, kMinItemsPerBlock = kItemsPerTile](auto kernel) {
-    if (shared) {
+    auto init = [&](auto& kernel) {
+      if (this->shared) {
         dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                            max_shared_memory));
       }
 
       // determine the launch configuration
-    int num_groups = feature_groups.NumGroups();
-    int n_mps = 0;
-    dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
-    int n_blocks_per_mp = 0;
-    dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
-                                                                kBlockThreads, smem_size));
-    // This gives the number of blocks to keep the device occupied
-    // Use this as the maximum number of blocks
-    unsigned grid_size = n_blocks_per_mp * n_mps;
+      std::int32_t num_groups = feature_groups.NumGroups();
+      std::int32_t n_mps = 0;
+      dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, ctx->Ordinal()));
 
+      std::int32_t n_blocks_per_mp = 0;
+      dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
+                                                                  kBlockThreads, this->smem_size));
+
+      // This gives the number of blocks to keep the device occupied Use this as the
+      // maximum number of blocks
+      this->grid_size = n_blocks_per_mp * n_mps;
+    };
+
+    init(this->global_kernel);
+    init(this->shared_kernel);
+  }
+};
+
+class DeviceHistogramBuilderImpl {
+  std::unique_ptr<HistogramKernel<>> kernel_{nullptr};
+  bool force_global_memory_{false};
+
+ public:
+  void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
+             bool force_global_memory) {
+    this->kernel_ = std::make_unique<HistogramKernel<>>(ctx, feature_groups, force_global_memory);
+    this->force_global_memory_ = force_global_memory;
+  }
+
+  void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
+                      FeatureGroupsAccessor const& feature_groups,
+                      common::Span<GradientPair const> gpair,
+                      common::Span<const std::uint32_t> d_ridx,
+                      common::Span<GradientPairInt64> histogram, GradientQuantiser rounding) {
+    CHECK(kernel_);
     // Otherwise launch blocks such that each block has a minimum amount of work to do
     // There are fixed costs to launching each block, e.g. zeroing shared memory
     // The below amount of minimum work was found by experimentation
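The HistogramKernel constructor above opts each kernel variant into the device's maximum dynamic shared memory before querying occupancy. A minimal sketch of that opt-in follows, with a hypothetical kernel name; the real code wraps the same CUDA attribute in dh::MaxSharedMemoryOptin and dh::safe_cuda.

#include <cuda_runtime.h>
#include <cstddef>

__global__ void SomeSharedMemKernel() { extern __shared__ char smem[]; (void)smem; }

// Kernels default to at most 48 KiB of dynamic shared memory; launches that
// request more fail unless the kernel is first opted into the larger carve-out.
inline cudaError_t OptInToSharedMemory(std::size_t max_smem_bytes) {
  return cudaFuncSetAttribute(SomeSharedMemKernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                              static_cast<int>(max_smem_bytes));
}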
@@ -300,20 +328,41 @@ void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const&
 
     // Allocate number of blocks such that each block has about kMinItemsPerBlock work
     // Up to a maximum where the device is saturated
-    grid_size = std::min(grid_size, static_cast<std::uint32_t>(
-                             common::DivRoundUp(items_per_group, kMinItemsPerBlock)));
+    auto constexpr kMinItemsPerBlock = ItemsPerTile();
+    auto grid_size = std::min(kernel_->grid_size, static_cast<std::uint32_t>(common::DivRoundUp(
+                                                      items_per_group, kMinItemsPerBlock)));
 
-    dh::LaunchKernel {dim3(grid_size, num_groups), static_cast<uint32_t>(kBlockThreads), smem_size,
-                      ctx->Stream()} (kernel, matrix, feature_groups, d_ridx, histogram.data(),
-                                      gpair.data(), rounding);
-  };
-
-  if (shared) {
-    runit(SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>);
+    if (this->force_global_memory_ || !this->kernel_->shared) {
+      dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()),  // NOLINT
+                       static_cast<uint32_t>(kBlockThreads), kernel_->smem_size,
+                       ctx->Stream()}(kernel_->global_kernel, matrix, feature_groups, d_ridx,
+                                      histogram.data(), gpair.data(), rounding);
     } else {
-    runit(SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>);
+      dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()),  // NOLINT
+                       static_cast<uint32_t>(kBlockThreads), kernel_->smem_size,
+                       ctx->Stream()}(kernel_->shared_kernel, matrix, feature_groups, d_ridx,
+                                      histogram.data(), gpair.data(), rounding);
     }
+  }
+};
 
-  dh::safe_cuda(cudaGetLastError());
+DeviceHistogramBuilder::DeviceHistogramBuilder()
+    : p_impl_{std::make_unique<DeviceHistogramBuilderImpl>()} {}
+
+DeviceHistogramBuilder::~DeviceHistogramBuilder() = default;
+
+void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
+                                   bool force_global_memory) {
+  this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
+}
+
+void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
+                                            EllpackDeviceAccessor const& matrix,
+                                            FeatureGroupsAccessor const& feature_groups,
+                                            common::Span<GradientPair const> gpair,
+                                            common::Span<const std::uint32_t> ridx,
+                                            common::Span<GradientPairInt64> histogram,
+                                            GradientQuantiser rounding) {
+  this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
 }
 }  // namespace xgboost::tree
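For a sense of the clamp applied above: with kBlockThreads = 1024 and kItemsPerThread = 8, kMinItemsPerBlock = ItemsPerTile() = 8192. If items_per_group is, say, 1,000,000 (an illustrative figure), common::DivRoundUp(1,000,000, 8192) = 123, so the launch uses min(cached grid size, 123) blocks; the cached, occupancy-derived grid size only becomes the binding limit when it is smaller than that.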
@@ -1,17 +1,18 @@
-/*!
- * Copyright 2020-2021 by XGBoost Contributors
+/**
+ * Copyright 2020-2024, XGBoost Contributors
  */
 #ifndef HISTOGRAM_CUH_
 #define HISTOGRAM_CUH_
-#include <thrust/transform.h>
+#include <memory>  // for unique_ptr
 
-#include "../../common/cuda_context.cuh"
-#include "../../data/ellpack_page.cuh"
-#include "feature_groups.cuh"
-
-namespace xgboost {
-namespace tree {
+#include "../../common/cuda_context.cuh"  // for CUDAContext
+#include "../../data/ellpack_page.cuh"    // for EllpackDeviceAccessor
+#include "feature_groups.cuh"             // for FeatureGroupsAccessor
+#include "xgboost/base.h"                 // for GradientPair, GradientPairInt64
+#include "xgboost/context.h"              // for Context
+#include "xgboost/span.h"                 // for Span
 
+namespace xgboost::tree {
 /**
  * \brief An atomicAdd designed for gradient pair with better performance. For general
  * int64_t atomicAdd, one can simply cast it to unsigned long long. Exposed for testing.
@@ -32,7 +33,7 @@ XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t* dst, int64_t src) {
 }
 
 class GradientQuantiser {
 private:
   /* Convert gradient to fixed point representation. */
   GradientPairPrecise to_fixed_point_;
   /* Convert fixed point representation back to floating point. */
@@ -59,13 +60,23 @@ private:
 }
 };
 
-void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
+class DeviceHistogramBuilderImpl;
+
+class DeviceHistogramBuilder {
+  std::unique_ptr<DeviceHistogramBuilderImpl> p_impl_;
+
+ public:
+  DeviceHistogramBuilder();
+  ~DeviceHistogramBuilder();
+
+  void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
+             bool force_global_memory);
+  void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
                       FeatureGroupsAccessor const& feature_groups,
                       common::Span<GradientPair const> gpair,
-                      common::Span<const uint32_t> ridx,
-                      common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
-                      bool force_global_memory = false);
-}  // namespace tree
-}  // namespace xgboost
+                      common::Span<const std::uint32_t> ridx,
+                      common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
+};
+}  // namespace xgboost::tree
 
 #endif  // HISTOGRAM_CUH_
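The header above forward-declares DeviceHistogramBuilderImpl and holds it behind a std::unique_ptr, which is why the constructor and destructor are declared here but defined in histogram.cu, where the impl type is complete. A generic sketch of that pimpl arrangement, with illustrative names and shown as a single translation unit:

// Illustrative pimpl skeleton (not XGBoost code); in the real layout the body
// of the impl class lives in the .cu file.
#include <memory>

class WidgetImpl;  // forward declaration only, as in the header above

class Widget {
  std::unique_ptr<WidgetImpl> p_impl_;

 public:
  Widget();
  ~Widget();  // declared here, defined after WidgetImpl is complete
  int Value() const;
};

class WidgetImpl {
 public:
  int Value() const { return 42; }
};

Widget::Widget() : p_impl_{std::make_unique<WidgetImpl>()} {}
Widget::~Widget() = default;  // unique_ptr's deleter now sees a complete type
int Widget::Value() const { return p_impl_->Value(); }

Defaulting the destructor inside the class definition instead would typically fail to compile, because unique_ptr would have to destroy the still-incomplete impl type wherever the outer object is destroyed.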
@@ -162,6 +162,8 @@ struct GPUHistMakerDevice {
   std::shared_ptr<common::ColumnSampler> column_sampler_;
   MetaInfo const& info_;
 
+  DeviceHistogramBuilder histogram_;
+
  public:
   EllpackPageImpl const* page{nullptr};
   common::Span<FeatureType const> feature_types;
@@ -256,6 +258,8 @@ struct GPUHistMakerDevice {
     hist.Reset();
 
     this->InitFeatureGroupsOnce();
+
+    this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false);
   }
 
   GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
@@ -340,7 +344,7 @@ struct GPUHistMakerDevice {
   void BuildHist(int nidx) {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
     auto d_ridx = row_partitioner->GetRows(nidx);
-    BuildGradientHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
+    this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
                                     feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
                                     d_node_hist, *quantiser);
   }
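Taken together with the histogram.cu changes, the call-site pattern above is: GPUHistMakerDevice owns one DeviceHistogramBuilder, Reset() configures it once when the device histograms are reset, and BuildHist(nidx) reuses the cached configuration for every node, so the occupancy query and shared-memory opt-in no longer run on each histogram build.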
@@ -1,11 +1,10 @@
 /**
- * Copyright 2020-2023, XGBoost Contributors
+ * Copyright 2020-2024, XGBoost Contributors
  */
 #include <gtest/gtest.h>
 
 #include <vector>
 
-#include "../../../../src/common/categorical.h"
 #include "../../../../src/tree/gpu_hist/histogram.cuh"
 #include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
 #include "../../../../src/tree/param.h"  // TrainParam
@@ -13,7 +12,7 @@
 #include "../../helpers.h"
 
 namespace xgboost::tree {
-void TestDeterministicHistogram(bool is_dense, int shm_size) {
+void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) {
   Context ctx = MakeCUDACtx(0);
   size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
   float constexpr kLower = -1e-2, kUpper = 1e2;
@@ -25,35 +24,37 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
   for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();
 
-    tree::RowPartitioner row_partitioner(FstCU(), kRows);
+    tree::RowPartitioner row_partitioner(ctx.Device(), kRows);
     auto ridx = row_partitioner.GetRows(0);
 
-    int num_bins = kBins * kCols;
+    bst_bin_t num_bins = kBins * kCols;
     dh::device_vector<GradientPairInt64> histogram(num_bins);
     auto d_histogram = dh::ToSpan(histogram);
     auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
-    gpair.SetDevice(FstCU());
+    gpair.SetDevice(ctx.Device());
 
-    FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
-                                 sizeof(GradientPairInt64));
+    FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(GradientPairInt64));
 
     auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
-    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
-                           feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
+    DeviceHistogramBuilder builder;
+    builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
+    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+                           feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                            d_histogram, quantiser);
 
     std::vector<GradientPairInt64> histogram_h(num_bins);
     dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
-                             num_bins * sizeof(GradientPairInt64),
-                             cudaMemcpyDeviceToHost));
+                             num_bins * sizeof(GradientPairInt64), cudaMemcpyDeviceToHost));
 
-    for (size_t i = 0; i < kRounds; ++i) {
+    for (std::size_t i = 0; i < kRounds; ++i) {
       dh::device_vector<GradientPairInt64> new_histogram(num_bins);
       auto d_new_histogram = dh::ToSpan(new_histogram);
 
       auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
-      BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
-                             feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
+      DeviceHistogramBuilder builder;
+      builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
+      builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+                             feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                              d_new_histogram, quantiser);
 
       std::vector<GradientPairInt64> new_histogram_h(num_bins);
@@ -68,14 +69,16 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
 
   {
     auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
-    gpair.SetDevice(FstCU());
+    gpair.SetDevice(ctx.Device());
 
     // Use a single feature group to compute the baseline.
     FeatureGroups single_group(page->Cuts());
 
     dh::device_vector<GradientPairInt64> baseline(num_bins);
-    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
-                           single_group.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
+    DeviceHistogramBuilder builder;
+    builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), force_global);
+    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+                           single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                            dh::ToSpan(baseline), quantiser);
 
     std::vector<GradientPairInt64> baseline_h(num_bins);
@@ -96,7 +99,9 @@ TEST(Histogram, GPUDeterministic) {
   std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
   for (bool is_dense : is_dense_array) {
     for (int shm_size : shm_sizes) {
-      TestDeterministicHistogram(is_dense, shm_size);
+      for (bool force_global : {true, false}) {
+        TestDeterministicHistogram(is_dense, shm_size, force_global);
+      }
     }
   }
 }
@@ -136,7 +141,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   for (auto const &batch : cat_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
-    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+    DeviceHistogramBuilder builder;
+    builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
+    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
                            single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                            dh::ToSpan(cat_hist), quantiser);
   }
@@ -150,7 +157,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   for (auto const &batch : encode_m->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
-    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+    DeviceHistogramBuilder builder;
+    builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
+    builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
                            single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
                            dh::ToSpan(encode_hist), quantiser);
   }
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2023 by XGBoost contributors
+ * Copyright 2017-2024, XGBoost contributors
  */
 #include <gtest/gtest.h>
 #include <thrust/device_vector.h>
@@ -22,12 +22,8 @@
 #include "xgboost/context.h"
 #include "xgboost/json.h"
 
-#if defined(XGBOOST_USE_FEDERATED)
-#include "../plugin/federated/test_worker.h"  // for TestFederatedGlobal
-#endif  // defined(XGBOOST_USE_FEDERATED)
-
 namespace xgboost::tree {
-TEST(GpuHist, DeviceHistogram) {
+TEST(GpuHist, DeviceHistogramStorage) {
   // Ensures that node allocates correctly after reaching `kStopGrowingSize`.
   dh::safe_cuda(cudaSetDevice(0));
   constexpr size_t kNBins = 128;
@@ -102,17 +98,17 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
   HostDeviceVector<GradientPair> gpair(kNRows);
-  for (auto &gp : gpair.HostVector()) {
-    bst_float grad = dist(&gen);
-    bst_float hess = dist(&gen);
-    gp = GradientPair(grad, hess);
+  for (auto& gp : gpair.HostVector()) {
+    float grad = dist(&gen);
+    float hess = dist(&gen);
+    gp = GradientPair{grad, hess};
   }
-  gpair.SetDevice(DeviceOrd::CUDA(0));
+  gpair.SetDevice(ctx.Device());
 
-  thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());
-  maker.row_partitioner = std::make_unique<RowPartitioner>(FstCU(), kNRows);
+  thrust::host_vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.HostVector());
+  maker.row_partitioner = std::make_unique<RowPartitioner>(ctx.Device(), kNRows);
 
-  maker.hist.Init(FstCU(), page->Cuts().TotalBins());
+  maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
   maker.hist.AllocateHistograms({0});
 
   maker.gpair = gpair.DeviceSpan();
@@ -121,10 +117,13 @@ void TestBuildHist(bool use_shared_memory_histograms) {
 
   maker.InitFeatureGroupsOnce();
 
-  BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(DeviceOrd::CUDA(0)),
-                         maker.feature_groups->DeviceAccessor(DeviceOrd::CUDA(0)), gpair.DeviceSpan(),
+  DeviceHistogramBuilder builder;
+  builder.Reset(&ctx, maker.feature_groups->DeviceAccessor(ctx.Device()),
+                !use_shared_memory_histograms);
+  builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
+                         maker.feature_groups->DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
                          maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0),
-                         *maker.quantiser, !use_shared_memory_histograms);
+                         *maker.quantiser);
 
   DeviceHistogramStorage<>& d_hist = maker.hist;
 