[breaking] Use integer atomic for GPU histogram. (#7180)

On GPU we use rouding factor to truncate the gradient for deterministic results. This PR changes the gradient representation to fixed point number with exponent aligned with rounding factor.

    [breaking] Drop non-deterministic histogram.
    Use fixed point for shared memory.

This PR is to improve the performance of GPU Hist. 

Co-authored-by: Andy Adinets <aadinets@nvidia.com>
This commit is contained in:
Jiaming Yuan 2021-08-28 05:17:05 +08:00 committed by GitHub
parent e7d7ab6bc3
commit 7a1d67f9cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 295 additions and 142 deletions

View File

@ -245,16 +245,6 @@ Additional parameters for ``hist`` and ``gpu_hist`` tree method
- Use single precision to build histograms instead of double precision.
Additional parameters for ``gpu_hist`` tree method
==================================================
* ``deterministic_histogram``, [default=``true``]
- Build histogram on GPU deterministically. Histogram building is not deterministic due
to the non-associative aspect of floating point summation. We employ a pre-rounding
routine to mitigate the issue, which may lead to slightly lower accuracy. Set to
``false`` to disable it.
Additional parameters for Dart Booster (``booster=dart``)
=========================================================

View File

@ -255,9 +255,12 @@ class GradientPairInternal {
/*! \brief gradient statistics pair usually needed in gradient boosting */
using GradientPair = detail::GradientPairInternal<float>;
/*! \brief High precision gradient statistics pair */
using GradientPairPrecise = detail::GradientPairInternal<double>;
/*! \brief Fixed point representation for gradient pair. */
using GradientPairInt32 = detail::GradientPairInternal<int>;
/*! \brief Fixed point representation for high precision gradient pair. */
using GradientPairInt64 = detail::GradientPairInternal<int64_t>;
using Args = std::vector<std::pair<std::string, std::string> >;

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2017-2020 XGBoost contributors
* Copyright 2017-2021 XGBoost contributors
*/
#pragma once
#include <thrust/device_ptr.h>
@ -53,27 +53,6 @@
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
#else // In device code and CUDA < 600
__device__ __forceinline__ double atomicAdd(double* address, double val) { // NOLINT
unsigned long long int* address_as_ull =
(unsigned long long int*)address; // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
// Note: uses integer comparison to avoid hang in case of NaN (since NaN !=
// NaN)
} while (assumed != old);
return __longlong_as_double(old);
}
#endif
namespace dh {
namespace detail {
template <size_t size>
@ -98,12 +77,11 @@ template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> * = // NOLINT
nullptr>
T __device__ __forceinline__ atomicAdd(T *addr, T v) { // NOLINT
XGBOOST_DEV_INLINE T atomicAdd(T *addr, T v) { // NOLINT
using Type = typename dh::detail::AtomicDispatcher<sizeof(T)>::Type;
Type ret = ::atomicAdd(reinterpret_cast<Type *>(addr), static_cast<Type>(v));
return static_cast<T>(ret);
}
namespace dh {
#ifdef XGBOOST_USE_NCCL
@ -1109,6 +1087,44 @@ XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
}
/**
* \brief An atomicAdd designed for gradient pair with better performance. For general
* int64_t atomicAdd, one can simply cast it to unsigned long long.
*/
XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t *dst, int64_t src) {
uint32_t* y_low = reinterpret_cast<uint32_t *>(dst);
uint32_t *y_high = y_low + 1;
auto cast_src = reinterpret_cast<uint64_t *>(&src);
uint32_t const x_low = static_cast<uint32_t>(src);
uint32_t const x_high = (*cast_src) >> 32;
auto const old = atomicAdd(y_low, x_low);
uint32_t const carry = old > (std::numeric_limits<uint32_t>::max() - x_low) ? 1 : 0;
uint32_t const sig = x_high + carry;
atomicAdd(y_high, sig);
}
XGBOOST_DEV_INLINE void
AtomicAddGpair(xgboost::GradientPairInt64 *dest,
xgboost::GradientPairInt64 const &gpair) {
auto dst_ptr = reinterpret_cast<int64_t *>(dest);
auto g = gpair.GetGrad();
auto h = gpair.GetHess();
AtomicAdd64As32(dst_ptr, g);
AtomicAdd64As32(dst_ptr + 1, h);
}
XGBOOST_DEV_INLINE void
AtomicAddGpair(xgboost::GradientPairInt32 *dest,
xgboost::GradientPairInt32 const &gpair) {
auto dst_ptr = reinterpret_cast<typename xgboost::GradientPairInt32::ValueT*>(dest);
::atomicAdd(dst_ptr, static_cast<int>(gpair.GetGrad()));
::atomicAdd(dst_ptr + 1, static_cast<int>(gpair.GetHess()));
}
// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2020 by XGBoost Contributors
* Copyright 2020-2021 by XGBoost Contributors
*/
#include <thrust/reduce.h>
#include <thrust/iterator/transform_iterator.h>
@ -34,7 +34,7 @@ namespace tree {
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
*/
template <typename T>
XGBOOST_DEV_INLINE __host__ T CreateRoundingFactor(T max_abs, int n) {
T CreateRoundingFactor(T max_abs, int n) {
T delta = max_abs / (static_cast<T>(1.0) - 2 * n * std::numeric_limits<T>::epsilon());
// Calculate ceil(log_2(delta)).
@ -78,7 +78,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
};
template <typename GradientSumT>
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const> gpair) {
using T = typename GradientSumT::ValueT;
dh::XGBCachingDeviceAllocator<char> alloc;
@ -94,26 +94,51 @@ GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
gpair.size()),
CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
gpair.size()) };
return histogram_rounding;
using IntT = typename HistRounding<GradientSumT>::SharedSumT::ValueT;
/**
* Factor for converting gradients from fixed-point to floating-point.
*/
GradientSumT to_floating_point =
histogram_rounding /
T(IntT(1) << (sizeof(typename GradientSumT::ValueT) * 8 -
2)); // keep 1 for sign bit
/**
* Factor for converting gradients from floating-point to fixed-point. For
* f64:
*
* Precision = 64 - 1 - log2(rounding)
*
* rounding is calcuated as exp(m), see the rounding factor calcuation for
* details.
*/
GradientSumT to_fixed_point = GradientSumT(
T(1) / to_floating_point.GetGrad(), T(1) / to_floating_point.GetHess());
return {histogram_rounding, to_fixed_point, to_floating_point};
}
template GradientPairPrecise CreateRoundingFactor(common::Span<GradientPair const> gpair);
template GradientPair CreateRoundingFactor(common::Span<GradientPair const> gpair);
template HistRounding<GradientPairPrecise>
CreateRoundingFactor(common::Span<GradientPair const> gpair);
template HistRounding<GradientPair>
CreateRoundingFactor(common::Span<GradientPair const> gpair);
template <typename GradientSumT>
template <typename GradientSumT, bool use_shared_memory_histograms>
__global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
FeatureGroupsAccessor feature_groups,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* __restrict__ d_node_hist,
const GradientPair* __restrict__ d_gpair,
GradientSumT const rounding,
bool use_shared_memory_histograms) {
HistRounding<GradientSumT> const rounding) {
using SharedSumT = typename HistRounding<GradientSumT>::SharedSumT;
using T = typename GradientSumT::ValueT;
extern __shared__ char smem[];
FeatureGroup group = feature_groups[blockIdx.y];
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
SharedSumT *smem_arr = reinterpret_cast<SharedSumT *>(smem);
if (use_shared_memory_histograms) {
dh::BlockFill(smem_arr, group.num_bins, GradientSumT());
dh::BlockFill(smem_arr, group.num_bins, SharedSumT());
__syncthreads();
}
int feature_stride = matrix.is_dense ? group.num_features : matrix.row_stride;
@ -123,16 +148,21 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature +
idx % feature_stride];
if (gidx != matrix.NumBins()) {
GradientSumT truncated {
TruncateWithRoundingFactor<T>(rounding.GetGrad(), d_gpair[ridx].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(), d_gpair[ridx].GetHess()),
};
// If we are not using shared memory, accumulate the values directly into
// global memory
GradientSumT* atomic_add_ptr =
use_shared_memory_histograms ? smem_arr : d_node_hist;
gidx = use_shared_memory_histograms ? gidx - group.start_bin : gidx;
dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated);
if (use_shared_memory_histograms) {
auto adjusted = rounding.ToFixedPoint(d_gpair[ridx]);
dh::AtomicAddGpair(smem_arr + gidx, adjusted);
} else {
GradientSumT truncated{
TruncateWithRoundingFactor<T>(rounding.rounding.GetGrad(),
d_gpair[ridx].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.rounding.GetHess(),
d_gpair[ridx].GetHess()),
};
dh::AtomicAddGpair(d_node_hist + gidx, truncated);
}
}
}
@ -140,12 +170,7 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
// Write shared memory back to global memory
__syncthreads();
for (auto i : dh::BlockStrideRange(0, group.num_bins)) {
GradientSumT truncated{
TruncateWithRoundingFactor<T>(rounding.GetGrad(),
smem_arr[i].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(),
smem_arr[i].GetHess()),
};
auto truncated = rounding.ToFloatingPoint(smem_arr[i]);
dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated);
}
}
@ -157,57 +182,68 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> d_ridx,
common::Span<GradientSumT> histogram,
GradientSumT rounding) {
HistRounding<GradientSumT> rounding,
bool force_global_memory) {
// decide whether to use shared memory
int device = 0;
dh::safe_cuda(cudaGetDevice(&device));
// opt into maximum shared memory for the kernel if necessary
int max_shared_memory = dh::MaxSharedMemoryOptin(device);
size_t smem_size = sizeof(GradientSumT) * feature_groups.max_group_bins;
bool shared = smem_size <= max_shared_memory;
size_t smem_size = sizeof(typename HistRounding<GradientSumT>::SharedSumT) *
feature_groups.max_group_bins;
bool shared = !force_global_memory && smem_size <= max_shared_memory;
smem_size = shared ? smem_size : 0;
// opt into maximum shared memory for the kernel if necessary
auto kernel = SharedMemHistKernel<GradientSumT>;
auto runit = [&](auto kernel) {
if (shared) {
dh::safe_cuda(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_memory));
}
// determine the launch configuration
int min_grid_size;
int block_threads = 1024;
dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize(
&min_grid_size, &block_threads, kernel, smem_size, 0));
int num_groups = feature_groups.NumGroups();
int n_mps = 0;
dh::safe_cuda(
cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
int n_blocks_per_mp = 0;
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&n_blocks_per_mp, kernel, block_threads, smem_size));
unsigned grid_size = n_blocks_per_mp * n_mps;
// TODO(canonizer): This is really a hack, find a better way to distribute
// the data among thread blocks. The intention is to generate enough thread
// blocks to fill the GPU, but avoid having too many thread blocks, as this
// is less efficient when the number of rows is low. At least one thread
// block per feature group is required. The number of thread blocks:
// - for num_groups <= num_groups_threshold, around grid_size * num_groups
// - for num_groups_threshold <= num_groups <= num_groups_threshold *
// grid_size,
// around grid_size * num_groups_threshold
// - for num_groups_threshold * grid_size <= num_groups, around num_groups
int num_groups_threshold = 4;
grid_size = common::DivRoundUp(
grid_size, common::DivRoundUp(num_groups, num_groups_threshold));
using T = typename GradientSumT::ValueT;
dh::LaunchKernel {dim3(grid_size, num_groups),
static_cast<uint32_t>(block_threads),
smem_size} (kernel, matrix, feature_groups, d_ridx,
histogram.data(), gpair.data(), rounding);
};
if (shared) {
dh::safe_cuda(cudaFuncSetAttribute
(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_memory));
runit(SharedMemHistKernel<GradientSumT, true>);
} else {
runit(SharedMemHistKernel<GradientSumT, false>);
}
// determine the launch configuration
int min_grid_size;
int block_threads = 1024;
dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize(
&min_grid_size, &block_threads, kernel, smem_size, 0));
int num_groups = feature_groups.NumGroups();
int n_mps = 0;
dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
int n_blocks_per_mp = 0;
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor
(&n_blocks_per_mp, kernel, block_threads, smem_size));
unsigned grid_size = n_blocks_per_mp * n_mps;
// TODO(canonizer): This is really a hack, find a better way to distribute the
// data among thread blocks.
// The intention is to generate enough thread blocks to fill the GPU, but
// avoid having too many thread blocks, as this is less efficient when the
// number of rows is low. At least one thread block per feature group is
// required.
// The number of thread blocks:
// - for num_groups <= num_groups_threshold, around grid_size * num_groups
// - for num_groups_threshold <= num_groups <= num_groups_threshold * grid_size,
// around grid_size * num_groups_threshold
// - for num_groups_threshold * grid_size <= num_groups, around num_groups
int num_groups_threshold = 4;
grid_size = common::DivRoundUp(grid_size,
common::DivRoundUp(num_groups, num_groups_threshold));
dh::LaunchKernel {
dim3(grid_size, num_groups), static_cast<uint32_t>(block_threads), smem_size} (
kernel,
matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding,
shared);
dh::safe_cuda(cudaGetLastError());
}
@ -217,7 +253,8 @@ template void BuildGradientHistogram<GradientPair>(
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPair> histogram,
GradientPair rounding);
HistRounding<GradientPair> rounding,
bool force_global_memory);
template void BuildGradientHistogram<GradientPairPrecise>(
EllpackDeviceAccessor const& matrix,
@ -225,7 +262,8 @@ template void BuildGradientHistogram<GradientPairPrecise>(
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPairPrecise> histogram,
GradientPairPrecise rounding);
HistRounding<GradientPairPrecise> rounding,
bool force_global_memory);
} // namespace tree
} // namespace xgboost

View File

@ -12,22 +12,57 @@
namespace xgboost {
namespace tree {
template <typename GradientSumT>
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair);
template <typename T, typename U>
XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, U const x) {
static_assert(sizeof(T) >= sizeof(U), "Rounding must have higher or equal precision.");
return (rounding_factor + static_cast<T>(x)) - rounding_factor;
}
/**
* Truncation factor for gradient, see comments in `CreateRoundingFactor()` for details.
*/
template <typename GradientSumT>
struct HistRounding {
/* Factor to truncate the gradient before building histogram for deterministic result. */
GradientSumT rounding;
/* Convert gradient to fixed point representation. */
GradientSumT to_fixed_point;
/* Convert fixed point representation back to floating point. */
GradientSumT to_floating_point;
/* Type used in shared memory. */
using SharedSumT = std::conditional_t<
std::is_same<typename GradientSumT::ValueT, float>::value,
GradientPairInt32, GradientPairInt64>;
using T = typename GradientSumT::ValueT;
XGBOOST_DEV_INLINE SharedSumT ToFixedPoint(GradientPair const& gpair) const {
auto adjusted = SharedSumT(T(gpair.GetGrad() * to_fixed_point.GetGrad()),
T(gpair.GetHess() * to_fixed_point.GetHess()));
return adjusted;
}
XGBOOST_DEV_INLINE GradientSumT ToFloatingPoint(SharedSumT const &gpair) const {
auto g = gpair.GetGrad() * to_floating_point.GetGrad();
auto h = gpair.GetHess() * to_floating_point.GetHess();
GradientSumT truncated{
TruncateWithRoundingFactor<T>(rounding.GetGrad(), g),
TruncateWithRoundingFactor<T>(rounding.GetHess(), h),
};
return truncated;
}
};
template <typename GradientSumT>
HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const> gpair);
template <typename GradientSumT>
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientSumT> histogram,
GradientSumT rounding);
HistRounding<GradientSumT> rounding,
bool force_global_memory = false);
} // namespace tree
} // namespace xgboost

View File

@ -46,14 +46,11 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
struct GPUHistMakerTrainParam
: public XGBoostParameter<GPUHistMakerTrainParam> {
bool single_precision_histogram;
bool deterministic_histogram;
bool debug_synchronize;
// declare parameters
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
"Use single precision to build histograms.");
DMLC_DECLARE_FIELD(deterministic_histogram).set_default(true).describe(
"Pre-round the gradient for obtaining deterministic gradient histogram.");
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
"Check if all distributed tree are identical after tree construction.");
}
@ -153,7 +150,7 @@ class DeviceHistogram {
*/
common::Span<GradientSumT> GetNodeHistogram(int nidx) {
CHECK(this->HistogramExists(nidx));
auto ptr = data_.data().get() + nidx_map_[nidx];
auto ptr = data_.data().get() + nidx_map_.at(nidx);
return common::Span<GradientSumT>(
reinterpret_cast<GradientSumT*>(ptr), n_bins_);
}
@ -179,9 +176,8 @@ struct GPUHistMakerDevice {
std::vector<GradientPair> node_sum_gradients;
TrainParam param;
bool deterministic_histogram;
GradientSumT histogram_rounding;
HistRounding<GradientSumT> histogram_rounding;
dh::PinnedMemory pinned;
@ -205,7 +201,6 @@ struct GPUHistMakerDevice {
TrainParam _param,
uint32_t column_sampler_seed,
uint32_t n_features,
bool deterministic_histogram,
BatchParam _batch_param)
: device_id(_device_id),
page(_page),
@ -214,7 +209,6 @@ struct GPUHistMakerDevice {
tree_evaluator(param, n_features, _device_id),
column_sampler(column_sampler_seed),
interaction_constraints(param, n_features),
deterministic_histogram{deterministic_histogram},
batch_param(_batch_param) {
sampler.reset(new GradientBasedSampler(
page, _n_rows, batch_param, param.subsample, param.sampling_method));
@ -227,9 +221,9 @@ struct GPUHistMakerDevice {
// Init histogram
hist.Init(device_id, page->Cuts().TotalBins());
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
feature_groups.reset(new FeatureGroups(
page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id),
sizeof(GradientSumT)));
feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
dh::MaxSharedMemoryOptin(device_id),
sizeof(GradientSumT)));
}
~GPUHistMakerDevice() { // NOLINT
@ -263,11 +257,7 @@ struct GPUHistMakerDevice {
page = sample.page;
gpair = sample.gpair;
if (deterministic_histogram) {
histogram_rounding = CreateRoundingFactor<GradientSumT>(this->gpair);
} else {
histogram_rounding = GradientSumT{0.0, 0.0};
}
histogram_rounding = CreateRoundingFactor<GradientSumT>(this->gpair);
row_partitioner.reset(); // Release the device memory first before reallocating
row_partitioner.reset(new RowPartitioner(device_id, sample.sample_rows));
@ -805,7 +795,6 @@ class GPUHistMakerSpecialised {
param_,
column_sampling_seed,
info_->num_col_,
hist_maker_param_.deterministic_histogram,
batch_param));
p_last_fmat_ = dmat;

View File

@ -1,7 +1,10 @@
/*!
* Copyright 2017 XGBoost contributors
* Copyright 2017-2021 XGBoost contributors
*/
#include <cstddef>
#include <cstdint>
#include <thrust/device_vector.h>
#include <vector>
#include <xgboost/base.h>
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/quantile.h"
@ -101,8 +104,6 @@ struct IsSorted {
} // namespace
namespace xgboost {
namespace common {
void TestSegmentedUniqueRegression(std::vector<SketchEntry> values, size_t n_duplicated) {
std::vector<bst_feature_t> segments{0, static_cast<bst_feature_t>(values.size())};
@ -194,5 +195,73 @@ TEST(DeviceHelpers, ArgSort) {
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(),
thrust::greater<size_t>{}));
}
} // namespace common
namespace {
// Atomic add as type cast for test.
XGBOOST_DEV_INLINE int64_t atomicAdd(int64_t *dst, int64_t src) { // NOLINT
uint64_t* u_dst = reinterpret_cast<uint64_t*>(dst);
uint64_t u_src = *reinterpret_cast<uint64_t*>(&src);
uint64_t ret = ::atomicAdd(u_dst, u_src);
return *reinterpret_cast<int64_t*>(&ret);
}
}
void TestAtomicAdd() {
size_t n_elements = 1024;
dh::device_vector<int64_t> result_a(1, 0);
auto d_result_a = result_a.data().get();
dh::device_vector<int64_t> result_b(1, 0);
auto d_result_b = result_b.data().get();
/**
* Test for simple inputs
*/
std::vector<int64_t> h_inputs(n_elements);
for (size_t i = 0; i < h_inputs.size(); ++i) {
h_inputs[i] = (i % 2 == 0) ? i : -i;
}
dh::device_vector<int64_t> inputs(h_inputs);
auto d_inputs = inputs.data().get();
dh::LaunchN(n_elements, [=] __device__(size_t i) {
dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
atomicAdd(d_result_b, d_inputs[i]);
});
ASSERT_EQ(result_a[0], result_b[0]);
/**
* Test for positive values that don't fit into 32 bit integer.
*/
thrust::fill(inputs.begin(), inputs.end(),
(std::numeric_limits<uint32_t>::max() / 2));
thrust::fill(result_a.begin(), result_a.end(), 0);
thrust::fill(result_b.begin(), result_b.end(), 0);
dh::LaunchN(n_elements, [=] __device__(size_t i) {
dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
atomicAdd(d_result_b, d_inputs[i]);
});
ASSERT_EQ(result_a[0], result_b[0]);
ASSERT_GT(result_a[0], std::numeric_limits<uint32_t>::max());
CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);
/**
* Test for negative values that don't fit into 32 bit integer.
*/
thrust::fill(inputs.begin(), inputs.end(),
(std::numeric_limits<int32_t>::min() / 2));
thrust::fill(result_a.begin(), result_a.end(), 0);
thrust::fill(result_b.begin(), result_b.end(), 0);
dh::LaunchN(n_elements, [=] __device__(size_t i) {
dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
atomicAdd(d_result_b, d_inputs[i]);
});
ASSERT_EQ(result_a[0], result_b[0]);
ASSERT_LT(result_a[0], std::numeric_limits<int32_t>::min());
CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);
}
TEST(AtomicAdd, Int64) {
TestAtomicAdd();
}
} // namespace xgboost

View File

@ -1,5 +1,9 @@
#include "test_ranking_obj.cc"
/*!
* Copyright 2019-2021 by XGBoost Contributors
*/
#include <thrust/host_vector.h>
#include "test_ranking_obj.cc"
#include "../../../src/objective/rank_obj.cu"
namespace xgboost {

View File

@ -1,8 +1,13 @@
/*!
* Copyright 2019-2021 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <vector>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
#include "../../helpers.h"

View File

@ -1,8 +1,9 @@
/*!
* Copyright 2017-2020 XGBoost contributors
* Copyright 2017-2021 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <dmlc/filesystem.h>
#include <xgboost/base.h>
#include <random>
@ -80,8 +81,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
param.Init(args);
auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{};
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param, kNCols, kNCols,
true, batch_param);
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param,
kNCols, kNCols, batch_param);
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
HostDeviceVector<GradientPair> gpair(kNRows);
@ -93,14 +94,18 @@ void TestBuildHist(bool use_shared_memory_histograms) {
gpair.SetDevice(0);
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());
maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
maker.hist.AllocateHistogram(0);
maker.gpair = gpair.DeviceSpan();
maker.histogram_rounding = CreateRoundingFactor<GradientSumT>(maker.gpair);;
maker.BuildHist(0);
DeviceHistogram<GradientSumT> d_hist = maker.hist;
BuildGradientHistogram(
page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
gpair.DeviceSpan(), maker.row_partitioner->GetRows(0),
maker.hist.GetNodeHistogram(0), maker.histogram_rounding,
!use_shared_memory_histograms);
DeviceHistogram<GradientSumT>& d_hist = maker.hist;
auto node_histogram = d_hist.GetNodeHistogram(0);
// d_hist.data stored in float, not gradient pair
@ -115,6 +120,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
std::vector<GradientPairPrecise> solution = GetHostHistGpair();
std::cout << std::fixed;
for (size_t i = 0; i < h_result.size(); ++i) {
ASSERT_FALSE(std::isnan(h_result[i].GetGrad()));
EXPECT_NEAR(h_result[i].GetGrad(), solution[i].GetGrad(), 0.01f);
EXPECT_NEAR(h_result[i].GetHess(), solution[i].GetHess(), 0.01f);
}
@ -158,7 +164,8 @@ TEST(GpuHist, ApplySplit) {
HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
feature_types.SetDevice(bparam.gpu_id);
tree::GPUHistMakerDevice<GradientPairPrecise> updater(
0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam);
0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols,
bparam);
updater.ApplySplit(candidate, &tree);
ASSERT_EQ(tree.GetSplitTypes().size(), 3);
@ -217,8 +224,8 @@ TEST(GpuHist, EvaluateRootSplit) {
// Initialize GPUHistMakerDevice
auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{};
GPUHistMakerDevice<GradientPairPrecise>
maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param);
GPUHistMakerDevice<GradientPairPrecise> maker(
0, page.get(), {}, kNRows, param, kNCols, kNCols, batch_param);
// Initialize GPUHistMakerDevice::node_sum_gradients
maker.node_sum_gradients = {};

View File

@ -55,9 +55,6 @@ class TestGPUBasicModels:
model_0, model_1 = self.run_cls(X, y, True)
assert model_0 == model_1
model_0, model_1 = self.run_cls(X, y, False)
assert model_0 != model_1
def test_invalid_gpu_id(self):
X = np.random.randn(10, 5) * 1e4
y = np.random.randint(0, 2, size=10) * 1e4