[breaking] Use integer atomic for GPU histogram. (#7180)
On GPU we use a rounding factor to truncate the gradient for deterministic results. This PR changes the gradient representation to a fixed-point number whose exponent is aligned with the rounding factor.
[breaking] Drop non-deterministic histogram.
Use fixed point for shared memory.
This PR improves the performance of GPU Hist.
Co-authored-by: Andy Adinets <aadinets@nvidia.com>
This commit is contained in:
parent e7d7ab6bc3
commit 7a1d67f9cb
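As a hedged illustration of the fixed-point scheme described in the commit message (not part of the commit itself), the host-side sketch below models how a gradient is scaled into a 64-bit fixed-point value whose scale is derived from the rounding factor (mirroring the `sizeof(T) * 8 - 2` shift that appears later in the diff), accumulated with plain integer adds, and converted back. The concrete rounding factor 2^10 is an assumption for the example.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical host-side model: `rounding` stands in for the value returned by
// CreateRoundingFactor(); the names to_fixed_point / to_floating_point echo the diff.
int main() {
  double rounding = std::ldexp(1.0, 10);                       // assume a rounding factor of 2^10
  double to_floating_point = rounding / std::ldexp(1.0, 62);   // 64 bits minus sign bit and headroom
  double to_fixed_point = 1.0 / to_floating_point;

  double grads[] = {0.25, -1.5, 3.0625, 0.125};
  int64_t acc = 0;                                             // what the integer atomics accumulate
  double reference = 0.0;
  for (double g : grads) {
    acc += static_cast<int64_t>(g * to_fixed_point);           // ToFixedPoint
    reference += g;
  }
  double back = static_cast<double>(acc) * to_floating_point;  // ToFloatingPoint
  std::printf("fixed-point sum = %.10f  float sum = %.10f\n", back, reference);
  return 0;
}

Because the integer sum is exact and associative, the result no longer depends on the order in which the atomics commit, which is what makes the histogram deterministic.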
@@ -245,16 +245,6 @@ Additional parameters for ``hist`` and ``gpu_hist`` tree method

  - Use single precision to build histograms instead of double precision.

Additional parameters for ``gpu_hist`` tree method
==================================================

* ``deterministic_histogram``, [default=``true``]

  - Build histogram on GPU deterministically. Histogram building is not deterministic due
    to the non-associative aspect of floating point summation. We employ a pre-rounding
    routine to mitigate the issue, which may lead to slightly lower accuracy. Set to
    ``false`` to disable it.

Additional parameters for Dart Booster (``booster=dart``)
=========================================================
@@ -255,9 +255,12 @@ class GradientPairInternal {

/*! \brief gradient statistics pair usually needed in gradient boosting */
using GradientPair = detail::GradientPairInternal<float>;

/*! \brief High precision gradient statistics pair */
using GradientPairPrecise = detail::GradientPairInternal<double>;
/*! \brief Fixed point representation for gradient pair. */
using GradientPairInt32 = detail::GradientPairInternal<int>;
/*! \brief Fixed point representation for high precision gradient pair. */
using GradientPairInt64 = detail::GradientPairInternal<int64_t>;

using Args = std::vector<std::pair<std::string, std::string> >;
@@ -1,5 +1,5 @@
/*!
 * Copyright 2017-2020 XGBoost contributors
 * Copyright 2017-2021 XGBoost contributors
 */
#pragma once
#include <thrust/device_ptr.h>
@@ -53,27 +53,6 @@

#endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)

#else  // In device code and CUDA < 600
__device__ __forceinline__ double atomicAdd(double* address, double val) {  // NOLINT
  unsigned long long int* address_as_ull =
      (unsigned long long int*)address;  // NOLINT
  unsigned long long int old = *address_as_ull, assumed;  // NOLINT

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN !=
    // NaN)
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif

namespace dh {
namespace detail {
template <size_t size>

@@ -98,12 +77,11 @@ template <typename T = size_t,
          std::enable_if_t<std::is_same<size_t, T>::value &&
                           !std::is_same<size_t, unsigned long long>::value> * =  // NOLINT
              nullptr>
T __device__ __forceinline__ atomicAdd(T *addr, T v) {  // NOLINT
XGBOOST_DEV_INLINE T atomicAdd(T *addr, T v) {  // NOLINT
  using Type = typename dh::detail::AtomicDispatcher<sizeof(T)>::Type;
  Type ret = ::atomicAdd(reinterpret_cast<Type *>(addr), static_cast<Type>(v));
  return static_cast<T>(ret);
}

namespace dh {

#ifdef XGBOOST_USE_NCCL
@@ -1109,6 +1087,44 @@ XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
                static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
}

/**
 * \brief An atomicAdd designed for gradient pair with better performance. For general
 *        int64_t atomicAdd, one can simply cast it to unsigned long long.
 */
XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t *dst, int64_t src) {
  uint32_t* y_low = reinterpret_cast<uint32_t *>(dst);
  uint32_t *y_high = y_low + 1;

  auto cast_src = reinterpret_cast<uint64_t *>(&src);

  uint32_t const x_low = static_cast<uint32_t>(src);
  uint32_t const x_high = (*cast_src) >> 32;

  auto const old = atomicAdd(y_low, x_low);
  uint32_t const carry = old > (std::numeric_limits<uint32_t>::max() - x_low) ? 1 : 0;
  uint32_t const sig = x_high + carry;
  atomicAdd(y_high, sig);
}

XGBOOST_DEV_INLINE void
AtomicAddGpair(xgboost::GradientPairInt64 *dest,
               xgboost::GradientPairInt64 const &gpair) {
  auto dst_ptr = reinterpret_cast<int64_t *>(dest);
  auto g = gpair.GetGrad();
  auto h = gpair.GetHess();

  AtomicAdd64As32(dst_ptr, g);
  AtomicAdd64As32(dst_ptr + 1, h);
}

XGBOOST_DEV_INLINE void
AtomicAddGpair(xgboost::GradientPairInt32 *dest,
               xgboost::GradientPairInt32 const &gpair) {
  auto dst_ptr = reinterpret_cast<typename xgboost::GradientPairInt32::ValueT*>(dest);

  ::atomicAdd(dst_ptr, static_cast<int>(gpair.GetGrad()));
  ::atomicAdd(dst_ptr + 1, static_cast<int>(gpair.GetHess()));
}

// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
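The helper above relies on carry propagation between two 32-bit atomics. A hedged host-side model of the same arithmetic (not from the diff) is sketched below; because two's-complement addition is modular, adding the low words with wrap-around and feeding the carry into the high word gives the same result as a single 64-bit add.

#include <cassert>
#include <cstdint>
#include <limits>

// Host-side model of AtomicAdd64As32: split the 64-bit operand, add the low
// halves (which may wrap), then add the high halves plus the detected carry.
std::uint64_t SplitAdd64(std::uint64_t dst, std::int64_t src) {
  std::uint32_t lo = static_cast<std::uint32_t>(dst);
  std::uint32_t hi = static_cast<std::uint32_t>(dst >> 32);
  std::uint32_t x_lo = static_cast<std::uint32_t>(src);
  std::uint32_t x_hi = static_cast<std::uint32_t>(static_cast<std::uint64_t>(src) >> 32);

  std::uint32_t old_lo = lo;
  lo += x_lo;  // wraps modulo 2^32, like the 32-bit atomicAdd on the low word
  std::uint32_t carry =
      old_lo > std::numeric_limits<std::uint32_t>::max() - x_lo ? 1u : 0u;
  hi += x_hi + carry;
  return (static_cast<std::uint64_t>(hi) << 32) | lo;
}

int main() {
  std::uint64_t acc = 0;
  std::int64_t expected = 0;
  std::int64_t const values[] = {1, -7, (1LL << 40), -(1LL << 35), 123456789};
  for (std::int64_t v : values) {
    acc = SplitAdd64(acc, v);
    expected += v;
  }
  assert(static_cast<std::int64_t>(acc) == expected);
  return 0;
}

On the device the two 32-bit atomics are not one indivisible 64-bit update, so a concurrent reader could observe a torn intermediate value; the final histogram is still correct because every contribution and every carry is eventually added exactly once.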
@@ -1,5 +1,5 @@
/*!
 * Copyright 2020 by XGBoost Contributors
 * Copyright 2020-2021 by XGBoost Contributors
 */
#include <thrust/reduce.h>
#include <thrust/iterator/transform_iterator.h>
@@ -34,7 +34,7 @@ namespace tree {
 * to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
 */
template <typename T>
XGBOOST_DEV_INLINE __host__ T CreateRoundingFactor(T max_abs, int n) {
T CreateRoundingFactor(T max_abs, int n) {
  T delta = max_abs / (static_cast<T>(1.0) - 2 * n * std::numeric_limits<T>::epsilon());

  // Calculate ceil(log_2(delta)).
@@ -78,7 +78,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
};

template <typename GradientSumT>
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const> gpair) {
  using T = typename GradientSumT::ValueT;
  dh::XGBCachingDeviceAllocator<char> alloc;
@@ -94,26 +94,51 @@ GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
                              gpair.size()),
      CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
                              gpair.size()) };
  return histogram_rounding;

  using IntT = typename HistRounding<GradientSumT>::SharedSumT::ValueT;

  /**
   * Factor for converting gradients from fixed-point to floating-point.
   */
  GradientSumT to_floating_point =
      histogram_rounding /
      T(IntT(1) << (sizeof(typename GradientSumT::ValueT) * 8 -
                    2));  // keep 1 for sign bit
  /**
   * Factor for converting gradients from floating-point to fixed-point. For
   * f64:
   *
   *   Precision = 64 - 1 - log2(rounding)
   *
   * rounding is calculated as exp(m), see the rounding factor calculation for
   * details.
   */
  GradientSumT to_fixed_point = GradientSumT(
      T(1) / to_floating_point.GetGrad(), T(1) / to_floating_point.GetHess());

  return {histogram_rounding, to_fixed_point, to_floating_point};
}

template GradientPairPrecise CreateRoundingFactor(common::Span<GradientPair const> gpair);
template GradientPair CreateRoundingFactor(common::Span<GradientPair const> gpair);
template HistRounding<GradientPairPrecise>
CreateRoundingFactor(common::Span<GradientPair const> gpair);
template HistRounding<GradientPair>
CreateRoundingFactor(common::Span<GradientPair const> gpair);

template <typename GradientSumT>
template <typename GradientSumT, bool use_shared_memory_histograms>
__global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
                                    FeatureGroupsAccessor feature_groups,
                                    common::Span<const RowPartitioner::RowIndexT> d_ridx,
                                    GradientSumT* __restrict__ d_node_hist,
                                    const GradientPair* __restrict__ d_gpair,
                                    GradientSumT const rounding,
                                    bool use_shared_memory_histograms) {
                                    HistRounding<GradientSumT> const rounding) {
  using SharedSumT = typename HistRounding<GradientSumT>::SharedSumT;
  using T = typename GradientSumT::ValueT;

  extern __shared__ char smem[];
  FeatureGroup group = feature_groups[blockIdx.y];
  GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem);  // NOLINT
  SharedSumT *smem_arr = reinterpret_cast<SharedSumT *>(smem);
  if (use_shared_memory_histograms) {
    dh::BlockFill(smem_arr, group.num_bins, GradientSumT());
    dh::BlockFill(smem_arr, group.num_bins, SharedSumT());
    __syncthreads();
  }
  int feature_stride = matrix.is_dense ? group.num_features : matrix.row_stride;
@@ -123,16 +148,21 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
    int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature +
                                idx % feature_stride];
    if (gidx != matrix.NumBins()) {
      GradientSumT truncated {
        TruncateWithRoundingFactor<T>(rounding.GetGrad(), d_gpair[ridx].GetGrad()),
        TruncateWithRoundingFactor<T>(rounding.GetHess(), d_gpair[ridx].GetHess()),
      };
      // If we are not using shared memory, accumulate the values directly into
      // global memory
      GradientSumT* atomic_add_ptr =
          use_shared_memory_histograms ? smem_arr : d_node_hist;
      gidx = use_shared_memory_histograms ? gidx - group.start_bin : gidx;
      dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated);
      if (use_shared_memory_histograms) {
        auto adjusted = rounding.ToFixedPoint(d_gpair[ridx]);
        dh::AtomicAddGpair(smem_arr + gidx, adjusted);
      } else {
        GradientSumT truncated{
          TruncateWithRoundingFactor<T>(rounding.rounding.GetGrad(),
                                        d_gpair[ridx].GetGrad()),
          TruncateWithRoundingFactor<T>(rounding.rounding.GetHess(),
                                        d_gpair[ridx].GetHess()),
        };
        dh::AtomicAddGpair(d_node_hist + gidx, truncated);
      }
    }
  }
@@ -140,12 +170,7 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
    // Write shared memory back to global memory
    __syncthreads();
    for (auto i : dh::BlockStrideRange(0, group.num_bins)) {
      GradientSumT truncated{
        TruncateWithRoundingFactor<T>(rounding.GetGrad(),
                                      smem_arr[i].GetGrad()),
        TruncateWithRoundingFactor<T>(rounding.GetHess(),
                                      smem_arr[i].GetHess()),
      };
      auto truncated = rounding.ToFloatingPoint(smem_arr[i]);
      dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated);
    }
  }
@@ -157,57 +182,68 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
                            common::Span<GradientPair const> gpair,
                            common::Span<const uint32_t> d_ridx,
                            common::Span<GradientSumT> histogram,
                            GradientSumT rounding) {
                            HistRounding<GradientSumT> rounding,
                            bool force_global_memory) {
  // decide whether to use shared memory
  int device = 0;
  dh::safe_cuda(cudaGetDevice(&device));
  // opt into maximum shared memory for the kernel if necessary
  int max_shared_memory = dh::MaxSharedMemoryOptin(device);
  size_t smem_size = sizeof(GradientSumT) * feature_groups.max_group_bins;
  bool shared = smem_size <= max_shared_memory;

  size_t smem_size = sizeof(typename HistRounding<GradientSumT>::SharedSumT) *
                     feature_groups.max_group_bins;
  bool shared = !force_global_memory && smem_size <= max_shared_memory;
  smem_size = shared ? smem_size : 0;

  // opt into maximum shared memory for the kernel if necessary
  auto kernel = SharedMemHistKernel<GradientSumT>;
  auto runit = [&](auto kernel) {
    if (shared) {
      dh::safe_cuda(cudaFuncSetAttribute(
          kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
          max_shared_memory));
    }

    // determine the launch configuration
    int min_grid_size;
    int block_threads = 1024;
    dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize(
        &min_grid_size, &block_threads, kernel, smem_size, 0));

    int num_groups = feature_groups.NumGroups();
    int n_mps = 0;
    dh::safe_cuda(
        cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
    int n_blocks_per_mp = 0;
    dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &n_blocks_per_mp, kernel, block_threads, smem_size));
    unsigned grid_size = n_blocks_per_mp * n_mps;

    // TODO(canonizer): This is really a hack, find a better way to distribute
    // the data among thread blocks. The intention is to generate enough thread
    // blocks to fill the GPU, but avoid having too many thread blocks, as this
    // is less efficient when the number of rows is low. At least one thread
    // block per feature group is required. The number of thread blocks:
    // - for num_groups <= num_groups_threshold, around grid_size * num_groups
    // - for num_groups_threshold <= num_groups <= num_groups_threshold *
    //   grid_size,
    //   around grid_size * num_groups_threshold
    // - for num_groups_threshold * grid_size <= num_groups, around num_groups
    int num_groups_threshold = 4;
    grid_size = common::DivRoundUp(
        grid_size, common::DivRoundUp(num_groups, num_groups_threshold));

    using T = typename GradientSumT::ValueT;
    dh::LaunchKernel {dim3(grid_size, num_groups),
                      static_cast<uint32_t>(block_threads),
                      smem_size} (kernel, matrix, feature_groups, d_ridx,
                                  histogram.data(), gpair.data(), rounding);
  };

  if (shared) {
    dh::safe_cuda(cudaFuncSetAttribute
                  (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                   max_shared_memory));
    runit(SharedMemHistKernel<GradientSumT, true>);
  } else {
    runit(SharedMemHistKernel<GradientSumT, false>);
  }

  // determine the launch configuration
  int min_grid_size;
  int block_threads = 1024;
  dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize(
      &min_grid_size, &block_threads, kernel, smem_size, 0));

  int num_groups = feature_groups.NumGroups();
  int n_mps = 0;
  dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
  int n_blocks_per_mp = 0;
  dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor
                (&n_blocks_per_mp, kernel, block_threads, smem_size));
  unsigned grid_size = n_blocks_per_mp * n_mps;

  // TODO(canonizer): This is really a hack, find a better way to distribute the
  // data among thread blocks.
  // The intention is to generate enough thread blocks to fill the GPU, but
  // avoid having too many thread blocks, as this is less efficient when the
  // number of rows is low. At least one thread block per feature group is
  // required.
  // The number of thread blocks:
  // - for num_groups <= num_groups_threshold, around grid_size * num_groups
  // - for num_groups_threshold <= num_groups <= num_groups_threshold * grid_size,
  //   around grid_size * num_groups_threshold
  // - for num_groups_threshold * grid_size <= num_groups, around num_groups
  int num_groups_threshold = 4;
  grid_size = common::DivRoundUp(grid_size,
                                 common::DivRoundUp(num_groups, num_groups_threshold));

  dh::LaunchKernel {
    dim3(grid_size, num_groups), static_cast<uint32_t>(block_threads), smem_size} (
      kernel,
      matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding,
      shared);
  dh::safe_cuda(cudaGetLastError());
}
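For reference, the launch configuration above leans on the CUDA occupancy API rather than a hard-coded grid. The standalone sketch below shows the same pattern of deriving block and grid sizes; DummyHistKernel and the shared-memory budget are placeholders, not taken from the diff.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void DummyHistKernel(float *out) {
  extern __shared__ float smem[];           // dynamic shared memory, like the histogram kernel
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
  int device = 0;
  cudaGetDevice(&device);
  size_t smem_size = 1024 * sizeof(float);  // per-block dynamic shared memory (placeholder)

  // Let the occupancy calculator pick a block size for this kernel and smem budget.
  int min_grid_size = 0, block_threads = 0;
  cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_threads, DummyHistKernel,
                                     smem_size, 0);

  // Size the grid so each SM carries as many resident blocks as the kernel allows.
  int n_mps = 0;
  cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device);
  int n_blocks_per_mp = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, DummyHistKernel,
                                                block_threads, smem_size);
  unsigned grid_size = static_cast<unsigned>(n_blocks_per_mp) * n_mps;

  std::printf("block=%d blocks/SM=%d grid=%u\n", block_threads, n_blocks_per_mp, grid_size);
  return 0;
}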
@@ -217,7 +253,8 @@ template void BuildGradientHistogram<GradientPair>(
    common::Span<GradientPair const> gpair,
    common::Span<const uint32_t> ridx,
    common::Span<GradientPair> histogram,
    GradientPair rounding);
    HistRounding<GradientPair> rounding,
    bool force_global_memory);

template void BuildGradientHistogram<GradientPairPrecise>(
    EllpackDeviceAccessor const& matrix,

@@ -225,7 +262,8 @@ template void BuildGradientHistogram<GradientPairPrecise>(
    common::Span<GradientPair const> gpair,
    common::Span<const uint32_t> ridx,
    common::Span<GradientPairPrecise> histogram,
    GradientPairPrecise rounding);
    HistRounding<GradientPairPrecise> rounding,
    bool force_global_memory);

}  // namespace tree
}  // namespace xgboost
@@ -12,22 +12,57 @@
namespace xgboost {
namespace tree {

template <typename GradientSumT>
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair);

template <typename T, typename U>
XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, U const x) {
  static_assert(sizeof(T) >= sizeof(U), "Rounding must have higher or equal precision.");
  return (rounding_factor + static_cast<T>(x)) - rounding_factor;
}

/**
 * Truncation factor for gradient, see comments in `CreateRoundingFactor()` for details.
 */
template <typename GradientSumT>
struct HistRounding {
  /* Factor to truncate the gradient before building histogram for deterministic result. */
  GradientSumT rounding;
  /* Convert gradient to fixed point representation. */
  GradientSumT to_fixed_point;
  /* Convert fixed point representation back to floating point. */
  GradientSumT to_floating_point;

  /* Type used in shared memory. */
  using SharedSumT = std::conditional_t<
      std::is_same<typename GradientSumT::ValueT, float>::value,
      GradientPairInt32, GradientPairInt64>;
  using T = typename GradientSumT::ValueT;

  XGBOOST_DEV_INLINE SharedSumT ToFixedPoint(GradientPair const& gpair) const {
    auto adjusted = SharedSumT(T(gpair.GetGrad() * to_fixed_point.GetGrad()),
                               T(gpair.GetHess() * to_fixed_point.GetHess()));
    return adjusted;
  }
  XGBOOST_DEV_INLINE GradientSumT ToFloatingPoint(SharedSumT const &gpair) const {
    auto g = gpair.GetGrad() * to_floating_point.GetGrad();
    auto h = gpair.GetHess() * to_floating_point.GetHess();
    GradientSumT truncated{
      TruncateWithRoundingFactor<T>(rounding.GetGrad(), g),
      TruncateWithRoundingFactor<T>(rounding.GetHess(), h),
    };
    return truncated;
  }
};

template <typename GradientSumT>
HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const> gpair);

template <typename GradientSumT>
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
                            FeatureGroupsAccessor const& feature_groups,
                            common::Span<GradientPair const> gpair,
                            common::Span<const uint32_t> ridx,
                            common::Span<GradientSumT> histogram,
                            GradientSumT rounding);
                            HistRounding<GradientSumT> rounding,
                            bool force_global_memory = false);
}  // namespace tree
}  // namespace xgboost
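The `(rounding_factor + x) - rounding_factor` trick above is what makes pre-rounded summation order-independent: adding a large power of two forces `x` to be rounded to that value's ulp, so every truncated gradient is a multiple of a common quantum. A hedged standalone demo follows; the concrete factor 2^16 is an assumption for the example, not taken from CreateRoundingFactor.

#include <cmath>
#include <cstdio>

// Simplified copy of TruncateWithRoundingFactor without the static_assert.
template <typename T, typename U>
T TruncateWithRoundingFactor(T const rounding_factor, U const x) {
  return (rounding_factor + static_cast<T>(x)) - rounding_factor;
}

int main() {
  float rounding = std::ldexp(1.0f, 16);   // assume a rounding factor of 2^16
  float g = 3.14159265f;
  float t = TruncateWithRoundingFactor<float>(rounding, g);
  // t is a multiple of ulp(2^16) = 2^(16 - 23), since float has 23 explicit mantissa bits.
  std::printf("%.7f -> %.7f (quantum %g)\n", g, t, std::ldexp(1.0f, 16 - 23));
  return 0;
}

HistRounding packages this factor together with the two scales that map between the truncated floating-point value and the integer representation used in shared memory.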
@@ -46,14 +46,11 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);

struct GPUHistMakerTrainParam
    : public XGBoostParameter<GPUHistMakerTrainParam> {
  bool single_precision_histogram;
  bool deterministic_histogram;
  bool debug_synchronize;
  // declare parameters
  DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
    DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
        "Use single precision to build histograms.");
    DMLC_DECLARE_FIELD(deterministic_histogram).set_default(true).describe(
        "Pre-round the gradient for obtaining deterministic gradient histogram.");
    DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
        "Check if all distributed tree are identical after tree construction.");
  }

@@ -153,7 +150,7 @@ class DeviceHistogram {
   */
  common::Span<GradientSumT> GetNodeHistogram(int nidx) {
    CHECK(this->HistogramExists(nidx));
    auto ptr = data_.data().get() + nidx_map_[nidx];
    auto ptr = data_.data().get() + nidx_map_.at(nidx);
    return common::Span<GradientSumT>(
        reinterpret_cast<GradientSumT*>(ptr), n_bins_);
  }
@@ -179,9 +176,8 @@ struct GPUHistMakerDevice {
  std::vector<GradientPair> node_sum_gradients;

  TrainParam param;
  bool deterministic_histogram;

  GradientSumT histogram_rounding;
  HistRounding<GradientSumT> histogram_rounding;

  dh::PinnedMemory pinned;

@@ -205,7 +201,6 @@ struct GPUHistMakerDevice {
                     TrainParam _param,
                     uint32_t column_sampler_seed,
                     uint32_t n_features,
                     bool deterministic_histogram,
                     BatchParam _batch_param)
      : device_id(_device_id),
        page(_page),

@@ -214,7 +209,6 @@ struct GPUHistMakerDevice {
        tree_evaluator(param, n_features, _device_id),
        column_sampler(column_sampler_seed),
        interaction_constraints(param, n_features),
        deterministic_histogram{deterministic_histogram},
        batch_param(_batch_param) {
    sampler.reset(new GradientBasedSampler(
        page, _n_rows, batch_param, param.subsample, param.sampling_method));

@@ -227,9 +221,9 @@ struct GPUHistMakerDevice {
    // Init histogram
    hist.Init(device_id, page->Cuts().TotalBins());
    monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
    feature_groups.reset(new FeatureGroups(
        page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id),
        sizeof(GradientSumT)));
    feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
                                           dh::MaxSharedMemoryOptin(device_id),
                                           sizeof(GradientSumT)));
  }

  ~GPUHistMakerDevice() {  // NOLINT

@@ -263,11 +257,7 @@ struct GPUHistMakerDevice {
    page = sample.page;
    gpair = sample.gpair;

    if (deterministic_histogram) {
      histogram_rounding = CreateRoundingFactor<GradientSumT>(this->gpair);
    } else {
      histogram_rounding = GradientSumT{0.0, 0.0};
    }
    histogram_rounding = CreateRoundingFactor<GradientSumT>(this->gpair);

    row_partitioner.reset();  // Release the device memory first before reallocating
    row_partitioner.reset(new RowPartitioner(device_id, sample.sample_rows));

@@ -805,7 +795,6 @@ class GPUHistMakerSpecialised {
        param_,
        column_sampling_seed,
        info_->num_col_,
        hist_maker_param_.deterministic_histogram,
        batch_param));

    p_last_fmat_ = dmat;
@@ -1,7 +1,10 @@
/*!
 * Copyright 2017 XGBoost contributors
 * Copyright 2017-2021 XGBoost contributors
 */
#include <cstddef>
#include <cstdint>
#include <thrust/device_vector.h>
#include <vector>
#include <xgboost/base.h>
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/quantile.h"
@@ -101,8 +104,6 @@ struct IsSorted {
}  // namespace

namespace xgboost {
namespace common {

void TestSegmentedUniqueRegression(std::vector<SketchEntry> values, size_t n_duplicated) {
  std::vector<bst_feature_t> segments{0, static_cast<bst_feature_t>(values.size())};
@@ -194,5 +195,73 @@ TEST(DeviceHelpers, ArgSort) {
  ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(),
                                thrust::greater<size_t>{}));
}
}  // namespace common

namespace {
// Atomic add as type cast for test.
XGBOOST_DEV_INLINE int64_t atomicAdd(int64_t *dst, int64_t src) {  // NOLINT
  uint64_t* u_dst = reinterpret_cast<uint64_t*>(dst);
  uint64_t u_src = *reinterpret_cast<uint64_t*>(&src);
  uint64_t ret = ::atomicAdd(u_dst, u_src);
  return *reinterpret_cast<int64_t*>(&ret);
}
}

void TestAtomicAdd() {
  size_t n_elements = 1024;
  dh::device_vector<int64_t> result_a(1, 0);
  auto d_result_a = result_a.data().get();

  dh::device_vector<int64_t> result_b(1, 0);
  auto d_result_b = result_b.data().get();

  /**
   * Test for simple inputs
   */
  std::vector<int64_t> h_inputs(n_elements);
  for (size_t i = 0; i < h_inputs.size(); ++i) {
    h_inputs[i] = (i % 2 == 0) ? i : -i;
  }
  dh::device_vector<int64_t> inputs(h_inputs);
  auto d_inputs = inputs.data().get();

  dh::LaunchN(n_elements, [=] __device__(size_t i) {
    dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
    atomicAdd(d_result_b, d_inputs[i]);
  });
  ASSERT_EQ(result_a[0], result_b[0]);

  /**
   * Test for positive values that don't fit into 32 bit integer.
   */
  thrust::fill(inputs.begin(), inputs.end(),
               (std::numeric_limits<uint32_t>::max() / 2));
  thrust::fill(result_a.begin(), result_a.end(), 0);
  thrust::fill(result_b.begin(), result_b.end(), 0);
  dh::LaunchN(n_elements, [=] __device__(size_t i) {
    dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
    atomicAdd(d_result_b, d_inputs[i]);
  });
  ASSERT_EQ(result_a[0], result_b[0]);
  ASSERT_GT(result_a[0], std::numeric_limits<uint32_t>::max());
  CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);

  /**
   * Test for negative values that don't fit into 32 bit integer.
   */
  thrust::fill(inputs.begin(), inputs.end(),
               (std::numeric_limits<int32_t>::min() / 2));
  thrust::fill(result_a.begin(), result_a.end(), 0);
  thrust::fill(result_b.begin(), result_b.end(), 0);
  dh::LaunchN(n_elements, [=] __device__(size_t i) {
    dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
    atomicAdd(d_result_b, d_inputs[i]);
  });
  ASSERT_EQ(result_a[0], result_b[0]);
  ASSERT_LT(result_a[0], std::numeric_limits<int32_t>::min());
  CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);
}

TEST(AtomicAdd, Int64) {
  TestAtomicAdd();
}
}  // namespace xgboost
@@ -1,5 +1,9 @@
#include "test_ranking_obj.cc"
/*!
 * Copyright 2019-2021 by XGBoost Contributors
 */
#include <thrust/host_vector.h>

#include "test_ranking_obj.cc"
#include "../../../src/objective/rank_obj.cu"

namespace xgboost {
@@ -1,8 +1,13 @@
/*!
 * Copyright 2019-2021 by XGBoost Contributors
 */
#include <gtest/gtest.h>
#include <vector>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>

#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
#include "../../helpers.h"
@@ -1,8 +1,9 @@
/*!
 * Copyright 2017-2020 XGBoost contributors
 * Copyright 2017-2021 XGBoost contributors
 */
#include <gtest/gtest.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <dmlc/filesystem.h>
#include <xgboost/base.h>
#include <random>

@@ -80,8 +81,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
  param.Init(args);
  auto page = BuildEllpackPage(kNRows, kNCols);
  BatchParam batch_param{};
  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param, kNCols, kNCols,
                                         true, batch_param);
  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param,
                                         kNCols, kNCols, batch_param);
  xgboost::SimpleLCG gen;
  xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
  HostDeviceVector<GradientPair> gpair(kNRows);

@@ -93,14 +94,18 @@ void TestBuildHist(bool use_shared_memory_histograms) {
  gpair.SetDevice(0);

  thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());

  maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
  maker.hist.AllocateHistogram(0);
  maker.gpair = gpair.DeviceSpan();
  maker.histogram_rounding = CreateRoundingFactor<GradientSumT>(maker.gpair);

  maker.BuildHist(0);
  DeviceHistogram<GradientSumT> d_hist = maker.hist;
  BuildGradientHistogram(
      page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
      gpair.DeviceSpan(), maker.row_partitioner->GetRows(0),
      maker.hist.GetNodeHistogram(0), maker.histogram_rounding,
      !use_shared_memory_histograms);

  DeviceHistogram<GradientSumT>& d_hist = maker.hist;

  auto node_histogram = d_hist.GetNodeHistogram(0);
  // d_hist.data stored in float, not gradient pair

@@ -115,6 +120,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
  std::vector<GradientPairPrecise> solution = GetHostHistGpair();
  std::cout << std::fixed;
  for (size_t i = 0; i < h_result.size(); ++i) {
    ASSERT_FALSE(std::isnan(h_result[i].GetGrad()));
    EXPECT_NEAR(h_result[i].GetGrad(), solution[i].GetGrad(), 0.01f);
    EXPECT_NEAR(h_result[i].GetHess(), solution[i].GetHess(), 0.01f);
  }

@@ -158,7 +164,8 @@ TEST(GpuHist, ApplySplit) {
  HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
  feature_types.SetDevice(bparam.gpu_id);
  tree::GPUHistMakerDevice<GradientPairPrecise> updater(
      0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam);
      0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols,
      bparam);
  updater.ApplySplit(candidate, &tree);

  ASSERT_EQ(tree.GetSplitTypes().size(), 3);

@@ -217,8 +224,8 @@ TEST(GpuHist, EvaluateRootSplit) {
  // Initialize GPUHistMakerDevice
  auto page = BuildEllpackPage(kNRows, kNCols);
  BatchParam batch_param{};
  GPUHistMakerDevice<GradientPairPrecise>
      maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param);
  GPUHistMakerDevice<GradientPairPrecise> maker(
      0, page.get(), {}, kNRows, param, kNCols, kNCols, batch_param);
  // Initialize GPUHistMakerDevice::node_sum_gradients
  maker.node_sum_gradients = {};
@@ -55,9 +55,6 @@ class TestGPUBasicModels:
        model_0, model_1 = self.run_cls(X, y, True)
        assert model_0 == model_1

        model_0, model_1 = self.run_cls(X, y, False)
        assert model_0 != model_1

    def test_invalid_gpu_id(self):
        X = np.random.randn(10, 5) * 1e4
        y = np.random.randint(0, 2, size=10) * 1e4