diff --git a/doc/parameter.rst b/doc/parameter.rst index c78be7dbe..aff9a5d18 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -245,16 +245,6 @@ Additional parameters for ``hist`` and ``gpu_hist`` tree method - Use single precision to build histograms instead of double precision. -Additional parameters for ``gpu_hist`` tree method -================================================== - -* ``deterministic_histogram``, [default=``true``] - - - Build histogram on GPU deterministically. Histogram building is not deterministic due - to the non-associative aspect of floating point summation. We employ a pre-rounding - routine to mitigate the issue, which may lead to slightly lower accuracy. Set to - ``false`` to disable it. - Additional parameters for Dart Booster (``booster=dart``) ========================================================= diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 0e3fbd981..4ea2de31a 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -255,9 +255,12 @@ class GradientPairInternal { /*! \brief gradient statistics pair usually needed in gradient boosting */ using GradientPair = detail::GradientPairInternal; - /*! \brief High precision gradient statistics pair */ using GradientPairPrecise = detail::GradientPairInternal; +/*! \brief Fixed point representation for gradient pair. */ +using GradientPairInt32 = detail::GradientPairInternal; +/*! \brief Fixed point representation for high precision gradient pair. */ +using GradientPairInt64 = detail::GradientPairInternal; using Args = std::vector >; diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 36249d5f2..96ef35058 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2020 XGBoost contributors + * Copyright 2017-2021 XGBoost contributors */ #pragma once #include @@ -53,27 +53,6 @@ #endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__) - -#else // In device code and CUDA < 600 -__device__ __forceinline__ double atomicAdd(double* address, double val) { // NOLINT - unsigned long long int* address_as_ull = - (unsigned long long int*)address; // NOLINT - unsigned long long int old = *address_as_ull, assumed; // NOLINT - - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + __longlong_as_double(assumed))); - - // Note: uses integer comparison to avoid hang in case of NaN (since NaN != - // NaN) - } while (assumed != old); - - return __longlong_as_double(old); -} -#endif - namespace dh { namespace detail { template @@ -98,12 +77,11 @@ template ::value && !std::is_same::value> * = // NOLINT nullptr> -T __device__ __forceinline__ atomicAdd(T *addr, T v) { // NOLINT +XGBOOST_DEV_INLINE T atomicAdd(T *addr, T v) { // NOLINT using Type = typename dh::detail::AtomicDispatcher::Type; Type ret = ::atomicAdd(reinterpret_cast(addr), static_cast(v)); return static_cast(ret); } - namespace dh { #ifdef XGBOOST_USE_NCCL @@ -1109,6 +1087,44 @@ XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest, static_cast(gpair.GetHess())); } +/** + * \brief An atomicAdd designed for gradient pair with better performance. For general + * int64_t atomicAdd, one can simply cast it to unsigned long long. 
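+ *
+ * The 64-bit operand is treated as two 32-bit words: the low words are added with a
+ * 32-bit atomicAdd, and a carry is added to the high word whenever the low-word
+ * addition wraps around.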
+ */
+XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t *dst, int64_t src) {
+  uint32_t* y_low = reinterpret_cast<uint32_t *>(dst);
+  uint32_t *y_high = y_low + 1;
+
+  auto cast_src = reinterpret_cast<uint64_t *>(&src);
+
+  uint32_t const x_low = static_cast<uint32_t>(src);
+  uint32_t const x_high = (*cast_src) >> 32;
+
+  auto const old = atomicAdd(y_low, x_low);
+  uint32_t const carry = old > (std::numeric_limits<uint32_t>::max() - x_low) ? 1 : 0;
+  uint32_t const sig = x_high + carry;
+  atomicAdd(y_high, sig);
+}
+
+XGBOOST_DEV_INLINE void
+AtomicAddGpair(xgboost::GradientPairInt64 *dest,
+               xgboost::GradientPairInt64 const &gpair) {
+  auto dst_ptr = reinterpret_cast<int64_t *>(dest);
+  auto g = gpair.GetGrad();
+  auto h = gpair.GetHess();
+
+  AtomicAdd64As32(dst_ptr, g);
+  AtomicAdd64As32(dst_ptr + 1, h);
+}
+
+XGBOOST_DEV_INLINE void
+AtomicAddGpair(xgboost::GradientPairInt32 *dest,
+               xgboost::GradientPairInt32 const &gpair) {
+  auto dst_ptr = reinterpret_cast<int *>(dest);
+
+  ::atomicAdd(dst_ptr, static_cast<int>(gpair.GetGrad()));
+  ::atomicAdd(dst_ptr + 1, static_cast<int>(gpair.GetHess()));
+}
 
 // Thrust version of this function causes error on Windows
 template
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index aae2fbc04..791363a05 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2020 by XGBoost Contributors
+ * Copyright 2020-2021 by XGBoost Contributors
  */
 #include
 #include
@@ -34,7 +34,7 @@ namespace tree {
  * to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
  */
 template <typename T>
-XGBOOST_DEV_INLINE __host__ T CreateRoundingFactor(T max_abs, int n) {
+T CreateRoundingFactor(T max_abs, int n) {
   T delta = max_abs / (static_cast<T>(1.0) - 2 * n * std::numeric_limits<T>::epsilon());
 
   // Calculate ceil(log_2(delta)).
@@ -78,7 +78,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
 };
 
 template <typename GradientSumT>
-GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
+HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const> gpair) {
   using T = typename GradientSumT::ValueT;
   dh::XGBCachingDeviceAllocator<char> alloc;
 
@@ -94,26 +94,51 @@ GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
                               gpair.size()),
      CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
                              gpair.size())
   };
-  return histogram_rounding;
+
+  using IntT = typename HistRounding<GradientSumT>::SharedSumT::ValueT;
+
+  /**
+   * Factor for converting gradients from fixed-point to floating-point.
+   */
+  GradientSumT to_floating_point =
+      histogram_rounding /
+      T(IntT(1) << (sizeof(typename GradientSumT::ValueT) * 8 -
+                    2));  // keep 1 for sign bit
+  /**
+   * Factor for converting gradients from floating-point to fixed-point. For
+   * f64:
+   *
+   *   Precision = 64 - 1 - log2(rounding)
+   *
+   * rounding is calculated as exp(m), see the rounding factor calculation for
+   * details.
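+   *
+   * to_fixed_point is simply the reciprocal of to_floating_point, so a gradient g is
+   * stored in shared memory as the integer part of g * to_fixed_point.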
+ */ + GradientSumT to_fixed_point = GradientSumT( + T(1) / to_floating_point.GetGrad(), T(1) / to_floating_point.GetHess()); + + return {histogram_rounding, to_fixed_point, to_floating_point}; } -template GradientPairPrecise CreateRoundingFactor(common::Span gpair); -template GradientPair CreateRoundingFactor(common::Span gpair); +template HistRounding +CreateRoundingFactor(common::Span gpair); +template HistRounding +CreateRoundingFactor(common::Span gpair); -template +template __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix, FeatureGroupsAccessor feature_groups, common::Span d_ridx, GradientSumT* __restrict__ d_node_hist, const GradientPair* __restrict__ d_gpair, - GradientSumT const rounding, - bool use_shared_memory_histograms) { + HistRounding const rounding) { + using SharedSumT = typename HistRounding::SharedSumT; using T = typename GradientSumT::ValueT; + extern __shared__ char smem[]; FeatureGroup group = feature_groups[blockIdx.y]; - GradientSumT* smem_arr = reinterpret_cast(smem); // NOLINT + SharedSumT *smem_arr = reinterpret_cast(smem); if (use_shared_memory_histograms) { - dh::BlockFill(smem_arr, group.num_bins, GradientSumT()); + dh::BlockFill(smem_arr, group.num_bins, SharedSumT()); __syncthreads(); } int feature_stride = matrix.is_dense ? group.num_features : matrix.row_stride; @@ -123,16 +148,21 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix, int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride]; if (gidx != matrix.NumBins()) { - GradientSumT truncated { - TruncateWithRoundingFactor(rounding.GetGrad(), d_gpair[ridx].GetGrad()), - TruncateWithRoundingFactor(rounding.GetHess(), d_gpair[ridx].GetHess()), - }; // If we are not using shared memory, accumulate the values directly into // global memory - GradientSumT* atomic_add_ptr = - use_shared_memory_histograms ? smem_arr : d_node_hist; gidx = use_shared_memory_histograms ? 
gidx - group.start_bin : gidx; - dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated); + if (use_shared_memory_histograms) { + auto adjusted = rounding.ToFixedPoint(d_gpair[ridx]); + dh::AtomicAddGpair(smem_arr + gidx, adjusted); + } else { + GradientSumT truncated{ + TruncateWithRoundingFactor(rounding.rounding.GetGrad(), + d_gpair[ridx].GetGrad()), + TruncateWithRoundingFactor(rounding.rounding.GetHess(), + d_gpair[ridx].GetHess()), + }; + dh::AtomicAddGpair(d_node_hist + gidx, truncated); + } } } @@ -140,12 +170,7 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix, // Write shared memory back to global memory __syncthreads(); for (auto i : dh::BlockStrideRange(0, group.num_bins)) { - GradientSumT truncated{ - TruncateWithRoundingFactor(rounding.GetGrad(), - smem_arr[i].GetGrad()), - TruncateWithRoundingFactor(rounding.GetHess(), - smem_arr[i].GetHess()), - }; + auto truncated = rounding.ToFloatingPoint(smem_arr[i]); dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated); } } @@ -157,57 +182,68 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, common::Span gpair, common::Span d_ridx, common::Span histogram, - GradientSumT rounding) { + HistRounding rounding, + bool force_global_memory) { // decide whether to use shared memory int device = 0; dh::safe_cuda(cudaGetDevice(&device)); + // opt into maximum shared memory for the kernel if necessary int max_shared_memory = dh::MaxSharedMemoryOptin(device); - size_t smem_size = sizeof(GradientSumT) * feature_groups.max_group_bins; - bool shared = smem_size <= max_shared_memory; + + size_t smem_size = sizeof(typename HistRounding::SharedSumT) * + feature_groups.max_group_bins; + bool shared = !force_global_memory && smem_size <= max_shared_memory; smem_size = shared ? smem_size : 0; - // opt into maximum shared memory for the kernel if necessary - auto kernel = SharedMemHistKernel; + auto runit = [&](auto kernel) { + if (shared) { + dh::safe_cuda(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_memory)); + } + + // determine the launch configuration + int min_grid_size; + int block_threads = 1024; + dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize( + &min_grid_size, &block_threads, kernel, smem_size, 0)); + + int num_groups = feature_groups.NumGroups(); + int n_mps = 0; + dh::safe_cuda( + cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device)); + int n_blocks_per_mp = 0; + dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &n_blocks_per_mp, kernel, block_threads, smem_size)); + unsigned grid_size = n_blocks_per_mp * n_mps; + + // TODO(canonizer): This is really a hack, find a better way to distribute + // the data among thread blocks. The intention is to generate enough thread + // blocks to fill the GPU, but avoid having too many thread blocks, as this + // is less efficient when the number of rows is low. At least one thread + // block per feature group is required. 
The number of thread blocks: + // - for num_groups <= num_groups_threshold, around grid_size * num_groups + // - for num_groups_threshold <= num_groups <= num_groups_threshold * + // grid_size, + // around grid_size * num_groups_threshold + // - for num_groups_threshold * grid_size <= num_groups, around num_groups + int num_groups_threshold = 4; + grid_size = common::DivRoundUp( + grid_size, common::DivRoundUp(num_groups, num_groups_threshold)); + + using T = typename GradientSumT::ValueT; + dh::LaunchKernel {dim3(grid_size, num_groups), + static_cast(block_threads), + smem_size} (kernel, matrix, feature_groups, d_ridx, + histogram.data(), gpair.data(), rounding); + }; + if (shared) { - dh::safe_cuda(cudaFuncSetAttribute - (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_memory)); + runit(SharedMemHistKernel); + } else { + runit(SharedMemHistKernel); } - // determine the launch configuration - int min_grid_size; - int block_threads = 1024; - dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize( - &min_grid_size, &block_threads, kernel, smem_size, 0)); - - int num_groups = feature_groups.NumGroups(); - int n_mps = 0; - dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device)); - int n_blocks_per_mp = 0; - dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor - (&n_blocks_per_mp, kernel, block_threads, smem_size)); - unsigned grid_size = n_blocks_per_mp * n_mps; - - // TODO(canonizer): This is really a hack, find a better way to distribute the - // data among thread blocks. - // The intention is to generate enough thread blocks to fill the GPU, but - // avoid having too many thread blocks, as this is less efficient when the - // number of rows is low. At least one thread block per feature group is - // required. 
- // The number of thread blocks: - // - for num_groups <= num_groups_threshold, around grid_size * num_groups - // - for num_groups_threshold <= num_groups <= num_groups_threshold * grid_size, - // around grid_size * num_groups_threshold - // - for num_groups_threshold * grid_size <= num_groups, around num_groups - int num_groups_threshold = 4; - grid_size = common::DivRoundUp(grid_size, - common::DivRoundUp(num_groups, num_groups_threshold)); - - dh::LaunchKernel { - dim3(grid_size, num_groups), static_cast(block_threads), smem_size} ( - kernel, - matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding, - shared); dh::safe_cuda(cudaGetLastError()); } @@ -217,7 +253,8 @@ template void BuildGradientHistogram( common::Span gpair, common::Span ridx, common::Span histogram, - GradientPair rounding); + HistRounding rounding, + bool force_global_memory); template void BuildGradientHistogram( EllpackDeviceAccessor const& matrix, @@ -225,7 +262,8 @@ template void BuildGradientHistogram( common::Span gpair, common::Span ridx, common::Span histogram, - GradientPairPrecise rounding); + HistRounding rounding, + bool force_global_memory); } // namespace tree } // namespace xgboost diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh index 02e63bcad..a45083f76 100644 --- a/src/tree/gpu_hist/histogram.cuh +++ b/src/tree/gpu_hist/histogram.cuh @@ -12,22 +12,57 @@ namespace xgboost { namespace tree { -template -GradientSumT CreateRoundingFactor(common::Span gpair); - template XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, U const x) { static_assert(sizeof(T) >= sizeof(U), "Rounding must have higher or equal precision."); return (rounding_factor + static_cast(x)) - rounding_factor; } +/** + * Truncation factor for gradient, see comments in `CreateRoundingFactor()` for details. + */ +template +struct HistRounding { + /* Factor to truncate the gradient before building histogram for deterministic result. */ + GradientSumT rounding; + /* Convert gradient to fixed point representation. */ + GradientSumT to_fixed_point; + /* Convert fixed point representation back to floating point. */ + GradientSumT to_floating_point; + + /* Type used in shared memory. 
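+     Single precision (float) histograms accumulate into GradientPairInt32, double precision ones into GradientPairInt64.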
*/ + using SharedSumT = std::conditional_t< + std::is_same::value, + GradientPairInt32, GradientPairInt64>; + using T = typename GradientSumT::ValueT; + + XGBOOST_DEV_INLINE SharedSumT ToFixedPoint(GradientPair const& gpair) const { + auto adjusted = SharedSumT(T(gpair.GetGrad() * to_fixed_point.GetGrad()), + T(gpair.GetHess() * to_fixed_point.GetHess())); + return adjusted; + } + XGBOOST_DEV_INLINE GradientSumT ToFloatingPoint(SharedSumT const &gpair) const { + auto g = gpair.GetGrad() * to_floating_point.GetGrad(); + auto h = gpair.GetHess() * to_floating_point.GetHess(); + GradientSumT truncated{ + TruncateWithRoundingFactor(rounding.GetGrad(), g), + TruncateWithRoundingFactor(rounding.GetHess(), h), + }; + return truncated; + } +}; + +template +HistRounding CreateRoundingFactor(common::Span gpair); + template void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, common::Span gpair, common::Span ridx, common::Span histogram, - GradientSumT rounding); + HistRounding rounding, + bool force_global_memory = false); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 1e2673f05..7499293e7 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -46,14 +46,11 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); struct GPUHistMakerTrainParam : public XGBoostParameter { bool single_precision_histogram; - bool deterministic_histogram; bool debug_synchronize; // declare parameters DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) { DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( "Use single precision to build histograms."); - DMLC_DECLARE_FIELD(deterministic_histogram).set_default(true).describe( - "Pre-round the gradient for obtaining deterministic gradient histogram."); DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe( "Check if all distributed tree are identical after tree construction."); } @@ -153,7 +150,7 @@ class DeviceHistogram { */ common::Span GetNodeHistogram(int nidx) { CHECK(this->HistogramExists(nidx)); - auto ptr = data_.data().get() + nidx_map_[nidx]; + auto ptr = data_.data().get() + nidx_map_.at(nidx); return common::Span( reinterpret_cast(ptr), n_bins_); } @@ -179,9 +176,8 @@ struct GPUHistMakerDevice { std::vector node_sum_gradients; TrainParam param; - bool deterministic_histogram; - GradientSumT histogram_rounding; + HistRounding histogram_rounding; dh::PinnedMemory pinned; @@ -205,7 +201,6 @@ struct GPUHistMakerDevice { TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, - bool deterministic_histogram, BatchParam _batch_param) : device_id(_device_id), page(_page), @@ -214,7 +209,6 @@ struct GPUHistMakerDevice { tree_evaluator(param, n_features, _device_id), column_sampler(column_sampler_seed), interaction_constraints(param, n_features), - deterministic_histogram{deterministic_histogram}, batch_param(_batch_param) { sampler.reset(new GradientBasedSampler( page, _n_rows, batch_param, param.subsample, param.sampling_method)); @@ -227,9 +221,9 @@ struct GPUHistMakerDevice { // Init histogram hist.Init(device_id, page->Cuts().TotalBins()); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); - feature_groups.reset(new FeatureGroups( - page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id), - sizeof(GradientSumT))); + feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, + dh::MaxSharedMemoryOptin(device_id), + sizeof(GradientSumT))); 
} ~GPUHistMakerDevice() { // NOLINT @@ -263,11 +257,7 @@ struct GPUHistMakerDevice { page = sample.page; gpair = sample.gpair; - if (deterministic_histogram) { - histogram_rounding = CreateRoundingFactor(this->gpair); - } else { - histogram_rounding = GradientSumT{0.0, 0.0}; - } + histogram_rounding = CreateRoundingFactor(this->gpair); row_partitioner.reset(); // Release the device memory first before reallocating row_partitioner.reset(new RowPartitioner(device_id, sample.sample_rows)); @@ -805,7 +795,6 @@ class GPUHistMakerSpecialised { param_, column_sampling_seed, info_->num_col_, - hist_maker_param_.deterministic_histogram, batch_param)); p_last_fmat_ = dmat; diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index cb7176c00..6e8668bd2 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -1,7 +1,10 @@ /*! - * Copyright 2017 XGBoost contributors + * Copyright 2017-2021 XGBoost contributors */ +#include +#include #include +#include #include #include "../../../src/common/device_helpers.cuh" #include "../../../src/common/quantile.h" @@ -101,8 +104,6 @@ struct IsSorted { } // namespace namespace xgboost { -namespace common { - void TestSegmentedUniqueRegression(std::vector values, size_t n_duplicated) { std::vector segments{0, static_cast(values.size())}; @@ -194,5 +195,73 @@ TEST(DeviceHelpers, ArgSort) { ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(), thrust::greater{})); } -} // namespace common + +namespace { +// Atomic add as type cast for test. +XGBOOST_DEV_INLINE int64_t atomicAdd(int64_t *dst, int64_t src) { // NOLINT + uint64_t* u_dst = reinterpret_cast(dst); + uint64_t u_src = *reinterpret_cast(&src); + uint64_t ret = ::atomicAdd(u_dst, u_src); + return *reinterpret_cast(&ret); +} +} + +void TestAtomicAdd() { + size_t n_elements = 1024; + dh::device_vector result_a(1, 0); + auto d_result_a = result_a.data().get(); + + dh::device_vector result_b(1, 0); + auto d_result_b = result_b.data().get(); + + /** + * Test for simple inputs + */ + std::vector h_inputs(n_elements); + for (size_t i = 0; i < h_inputs.size(); ++i) { + h_inputs[i] = (i % 2 == 0) ? i : -i; + } + dh::device_vector inputs(h_inputs); + auto d_inputs = inputs.data().get(); + + dh::LaunchN(n_elements, [=] __device__(size_t i) { + dh::AtomicAdd64As32(d_result_a, d_inputs[i]); + atomicAdd(d_result_b, d_inputs[i]); + }); + ASSERT_EQ(result_a[0], result_b[0]); + + /** + * Test for positive values that don't fit into 32 bit integer. + */ + thrust::fill(inputs.begin(), inputs.end(), + (std::numeric_limits::max() / 2)); + thrust::fill(result_a.begin(), result_a.end(), 0); + thrust::fill(result_b.begin(), result_b.end(), 0); + dh::LaunchN(n_elements, [=] __device__(size_t i) { + dh::AtomicAdd64As32(d_result_a, d_inputs[i]); + atomicAdd(d_result_b, d_inputs[i]); + }); + ASSERT_EQ(result_a[0], result_b[0]); + ASSERT_GT(result_a[0], std::numeric_limits::max()); + CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]); + + /** + * Test for negative values that don't fit into 32 bit integer. 
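+   * Two's complement addition is sign-agnostic at the bit level, so the low/high carry
+   * scheme in AtomicAdd64As32 should match a plain 64-bit add for negative sums as well.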
+ */ + thrust::fill(inputs.begin(), inputs.end(), + (std::numeric_limits::min() / 2)); + thrust::fill(result_a.begin(), result_a.end(), 0); + thrust::fill(result_b.begin(), result_b.end(), 0); + dh::LaunchN(n_elements, [=] __device__(size_t i) { + dh::AtomicAdd64As32(d_result_a, d_inputs[i]); + atomicAdd(d_result_b, d_inputs[i]); + }); + ASSERT_EQ(result_a[0], result_b[0]); + ASSERT_LT(result_a[0], std::numeric_limits::min()); + CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]); +} + +TEST(AtomicAdd, Int64) { + TestAtomicAdd(); +} } // namespace xgboost diff --git a/tests/cpp/objective/test_ranking_obj_gpu.cu b/tests/cpp/objective/test_ranking_obj_gpu.cu index dc8fd267d..4cf736bf6 100644 --- a/tests/cpp/objective/test_ranking_obj_gpu.cu +++ b/tests/cpp/objective/test_ranking_obj_gpu.cu @@ -1,5 +1,9 @@ -#include "test_ranking_obj.cc" +/*! + * Copyright 2019-2021 by XGBoost Contributors + */ +#include +#include "test_ranking_obj.cc" #include "../../../src/objective/rank_obj.cu" namespace xgboost { diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu index 4879ca080..9b16cca53 100644 --- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu +++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu @@ -1,8 +1,13 @@ +/*! + * Copyright 2019-2021 by XGBoost Contributors + */ #include #include #include +#include #include + #include "../../../../src/tree/gpu_hist/row_partitioner.cuh" #include "../../helpers.h" diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 591dc43d2..72c225396 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -1,8 +1,9 @@ /*! - * Copyright 2017-2020 XGBoost contributors + * Copyright 2017-2021 XGBoost contributors */ #include #include +#include #include #include #include @@ -80,8 +81,8 @@ void TestBuildHist(bool use_shared_memory_histograms) { param.Init(args); auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; - GPUHistMakerDevice maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, - true, batch_param); + GPUHistMakerDevice maker(0, page.get(), {}, kNRows, param, + kNCols, kNCols, batch_param); xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); HostDeviceVector gpair(kNRows); @@ -93,14 +94,18 @@ void TestBuildHist(bool use_shared_memory_histograms) { gpair.SetDevice(0); thrust::host_vector h_gidx_buffer (page->gidx_buffer.HostVector()); - - maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); maker.hist.AllocateHistogram(0); maker.gpair = gpair.DeviceSpan(); + maker.histogram_rounding = CreateRoundingFactor(maker.gpair);; - maker.BuildHist(0); - DeviceHistogram d_hist = maker.hist; + BuildGradientHistogram( + page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0), + gpair.DeviceSpan(), maker.row_partitioner->GetRows(0), + maker.hist.GetNodeHistogram(0), maker.histogram_rounding, + !use_shared_memory_histograms); + + DeviceHistogram& d_hist = maker.hist; auto node_histogram = d_hist.GetNodeHistogram(0); // d_hist.data stored in float, not gradient pair @@ -115,6 +120,7 @@ void TestBuildHist(bool use_shared_memory_histograms) { std::vector solution = GetHostHistGpair(); std::cout << std::fixed; for (size_t i = 0; i < h_result.size(); ++i) { + ASSERT_FALSE(std::isnan(h_result[i].GetGrad())); EXPECT_NEAR(h_result[i].GetGrad(), solution[i].GetGrad(), 0.01f); EXPECT_NEAR(h_result[i].GetHess(), solution[i].GetHess(), 0.01f); } @@ -158,7 +164,8 @@ TEST(GpuHist, 
@@ -158,7 +164,8 @@ TEST(GpuHist, ApplySplit) {
   HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
   feature_types.SetDevice(bparam.gpu_id);
   tree::GPUHistMakerDevice<GradientPairPrecise> updater(
-      0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam);
+      0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols,
+      bparam);
   updater.ApplySplit(candidate, &tree);
 
   ASSERT_EQ(tree.GetSplitTypes().size(), 3);
@@ -217,8 +224,8 @@ TEST(GpuHist, EvaluateRootSplit) {
   // Initialize GPUHistMakerDevice
   auto page = BuildEllpackPage(kNRows, kNCols);
   BatchParam batch_param{};
-  GPUHistMakerDevice<GradientPairPrecise>
-      maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param);
+  GPUHistMakerDevice<GradientPairPrecise> maker(
+      0, page.get(), {}, kNRows, param, kNCols, kNCols, batch_param);
   // Initialize GPUHistMakerDevice::node_sum_gradients
   maker.node_sum_gradients = {};
 
diff --git a/tests/python-gpu/test_gpu_basic_models.py b/tests/python-gpu/test_gpu_basic_models.py
index dc556fdae..3f4099986 100644
--- a/tests/python-gpu/test_gpu_basic_models.py
+++ b/tests/python-gpu/test_gpu_basic_models.py
@@ -55,9 +55,6 @@ class TestGPUBasicModels:
         model_0, model_1 = self.run_cls(X, y, True)
         assert model_0 == model_1
 
-        model_0, model_1 = self.run_cls(X, y, False)
-        assert model_0 != model_1
-
     def test_invalid_gpu_id(self):
         X = np.random.randn(10, 5) * 1e4
         y = np.random.randint(0, 2, size=10) * 1e4
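
Illustrative note (not part of the patch): a minimal host-side sketch of the fixed-point round trip that the new HistRounding factors implement, assuming a double-precision histogram and a made-up power-of-two rounding value; in the patch the rounding comes from CreateRoundingFactor().

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical rounding factor; the real one is derived from the largest absolute
  // gradient and the number of rows in CreateRoundingFactor().
  double const rounding = 256.0;

  // Keep one bit for the sign, as in the patch: scale by 2^(64 - 2).
  double const to_floating_point = rounding / static_cast<double>(int64_t{1} << 62);
  double const to_fixed_point = 1.0 / to_floating_point;

  // A gradient is stored in shared memory as a 64-bit integer ...
  double const grad = 0.125;
  int64_t const fixed = static_cast<int64_t>(grad * to_fixed_point);

  // ... and scaled back to floating point once the block finishes its accumulation.
  double const restored = static_cast<double>(fixed) * to_floating_point;

  std::cout << "fixed: " << fixed << " restored: " << restored << "\n";
  // Integer sums are associative, so the per-bin result no longer depends on the order
  // in which threads add their contributions.
}

For single-precision (GradientPair) histograms the same round trip goes through GradientPairInt32 with a 2^(32 - 2) scale instead.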