diff --git a/doc/gpu/index.md b/doc/gpu/index.md index 4e20a87a9..da913f2f1 100644 --- a/doc/gpu/index.md +++ b/doc/gpu/index.md @@ -12,13 +12,13 @@ Specify the 'tree_method' parameter as one of the following algorithms. ### Algorithms ```eval_rst -+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ -| tree_method | Description | -+==============+===============================================================================================================================================+ -| gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' | -+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ -| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Faster and uses considerably less memory. Splits may be less accurate. | -+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ ++--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| tree_method | Description | ++==============+=================================================================================================================================================================================================================+ +| gpu_exact | The standard XGBoost tree construction algorithm. Performs exact search for splits. Slower and uses considerably more memory than 'gpu_hist' | ++--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| gpu_hist | Equivalent to the XGBoost fast histogram algorithm. Much faster and uses considerably less memory. NOTE: Cannot be used with labels larger in magnitude than 2^16 due to it's histogram aggregation algorithm. | ++--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ ``` ### Supported parameters diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 117c2f7a0..24ff4d397 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -171,19 +171,19 @@ class bst_gpair_internal { template<> inline XGBOOST_DEVICE float bst_gpair_internal::GetGrad() const { - return grad_ * 1e-5f; + return grad_ * 1e-4f; } template<> inline XGBOOST_DEVICE float bst_gpair_internal::GetHess() const { - return hess_ * 1e-5f; + return hess_ * 1e-4f; } template<> inline XGBOOST_DEVICE void bst_gpair_internal::SetGrad(float g) { - grad_ = static_cast(std::round(g * 1e5)); + grad_ = static_cast(std::round(g * 1e4)); } template<> inline XGBOOST_DEVICE void bst_gpair_internal::SetHess(float h) { - hess_ = static_cast(std::round(h * 1e5)); + hess_ = static_cast(std::round(h * 1e4)); } } // namespace detail @@ -194,10 +194,10 @@ typedef detail::bst_gpair_internal bst_gpair; /*! \brief High precision gradient statistics pair */ typedef detail::bst_gpair_internal bst_gpair_precise; - /*! \brief High precision gradient statistics pair with integer backed - * storage. Operators are associative where floating point versions are not - * associative. */ - typedef detail::bst_gpair_internal bst_gpair_integer; +/*! \brief High precision gradient statistics pair with integer backed + * storage. Operators are associative where floating point versions are not + * associative. */ +typedef detail::bst_gpair_internal bst_gpair_integer; /*! \brief small eps gap for minimum split decision. */ const bst_float rt_eps = 1e-6f; diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 32513d7a4..8edaf8578 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -15,6 +15,27 @@ namespace xgboost { namespace tree { +/** + * \fn void CheckGradientMax(const dh::dvec& gpair) + * + * \brief Check maximum gradient value is below 2^16. This is to prevent + * overflow when using integer gradient summation. + */ + +inline void CheckGradientMax(const dh::dvec& gpair) { + auto dptr = thrust::device_ptr( + reinterpret_cast(gpair.data())); + float abs_max = thrust::reduce(dptr, dptr + (gpair.size() * 2), 0.f, + [=] __device__(float a, float b) { + a = abs(a); + b = abs(b); + return max(a, b); + }); + + CHECK_LT(abs_max, std::pow(2.0f, 16.0f)) + << "Labels are too large for this algorithm. Rescale to less than 2^16."; +} + struct GPUTrainingParam { // minimum amount of hessian(weight) allowed in a child float min_child_weight; @@ -64,8 +85,8 @@ struct DeviceSplitCandidate { : loss_chg(-FLT_MAX), dir(LeftDir), fvalue(0), findex(-1) {} template - __host__ __device__ void Update(const DeviceSplitCandidate &other, - const param_t& param) { + __host__ __device__ void Update(const DeviceSplitCandidate& other, + const param_t& param) { if (other.loss_chg > loss_chg && other.left_sum.GetHess() >= param.min_child_weight && other.right_sum.GetHess() >= param.min_child_weight) { @@ -170,8 +191,10 @@ struct SumCallbackOp { }; template -__device__ inline float device_calc_loss_chg( - const GPUTrainingParam& param, const gpair_t& left, const gpair_t& parent_sum, const float& parent_gain) { +__device__ inline float device_calc_loss_chg(const GPUTrainingParam& param, + const gpair_t& left, + const gpair_t& parent_sum, + const float& parent_gain) { gpair_t right = parent_sum - left; float left_gain = CalcGain(param, left.GetGrad(), left.GetHess()); float right_gain = CalcGain(param, right.GetGrad(), right.GetHess()); @@ -187,8 +210,8 @@ __device__ float inline loss_chg_missing(const gpair_t& scan, bool& missing_left_out) { // NOLINT float missing_left_loss = device_calc_loss_chg(param, scan + missing, parent_sum, parent_gain); - float missing_right_loss = device_calc_loss_chg( - param, scan, parent_sum, parent_gain); + float missing_right_loss = + device_calc_loss_chg(param, scan, parent_sum, parent_gain); if (missing_left_loss >= missing_right_loss) { missing_left_out = true; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index bd7ca94e1..dd3f1a8e9 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -537,6 +537,9 @@ class GPUHistMaker : public TreeUpdater { device_gpair[d_idx].copy(gpair.begin() + device_row_segments[d_idx], gpair.begin() + device_row_segments[d_idx + 1]); + // Check gradients are within acceptable size range + CheckGradientMax(device_gpair[d_idx]); + subsample_gpair(&device_gpair[d_idx], param.subsample, device_row_segments[d_idx]); diff --git a/src/tree/updater_gpu_hist_experimental.cu b/src/tree/updater_gpu_hist_experimental.cu index 6b80b6100..977e3c637 100644 --- a/src/tree/updater_gpu_hist_experimental.cu +++ b/src/tree/updater_gpu_hist_experimental.cu @@ -334,6 +334,8 @@ struct DeviceShard { ridx_segments.front() = std::make_pair(0, ridx.size()); this->gpair.copy(host_gpair.begin() + row_start_idx, host_gpair.begin() + row_end_idx); + // Check gradients are within acceptable size range + CheckGradientMax(gpair); hist.Reset(); } @@ -551,8 +553,8 @@ class GPUHistMakerExperimental : public TreeUpdater { __device__ void CountLeft(int64_t* d_count, int val, int left_nidx) { unsigned ballot = __ballot(val == left_nidx); if (threadIdx.x % 32 == 0) { - atomicAdd(reinterpret_cast(d_count), // NOLINT - static_cast(__popc(ballot))); // NOLINT + atomicAdd(reinterpret_cast(d_count), // NOLINT + static_cast(__popc(ballot))); // NOLINT } }