Add warnings for large labels when using GPU histogram algorithms (#2834)
This commit is contained in:
@@ -15,6 +15,27 @@
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
/**
|
||||
* \fn void CheckGradientMax(const dh::dvec<bst_gpair>& gpair)
|
||||
*
|
||||
* \brief Check maximum gradient value is below 2^16. This is to prevent
|
||||
* overflow when using integer gradient summation.
|
||||
*/
|
||||
|
||||
inline void CheckGradientMax(const dh::dvec<bst_gpair>& gpair) {
|
||||
auto dptr = thrust::device_ptr<const float>(
|
||||
reinterpret_cast<const float*>(gpair.data()));
|
||||
float abs_max = thrust::reduce(dptr, dptr + (gpair.size() * 2), 0.f,
|
||||
[=] __device__(float a, float b) {
|
||||
a = abs(a);
|
||||
b = abs(b);
|
||||
return max(a, b);
|
||||
});
|
||||
|
||||
CHECK_LT(abs_max, std::pow(2.0f, 16.0f))
|
||||
<< "Labels are too large for this algorithm. Rescale to less than 2^16.";
|
||||
}
|
||||
|
||||
struct GPUTrainingParam {
|
||||
// minimum amount of hessian(weight) allowed in a child
|
||||
float min_child_weight;
|
||||
@@ -64,8 +85,8 @@ struct DeviceSplitCandidate {
|
||||
: loss_chg(-FLT_MAX), dir(LeftDir), fvalue(0), findex(-1) {}
|
||||
|
||||
template <typename param_t>
|
||||
__host__ __device__ void Update(const DeviceSplitCandidate &other,
|
||||
const param_t& param) {
|
||||
__host__ __device__ void Update(const DeviceSplitCandidate& other,
|
||||
const param_t& param) {
|
||||
if (other.loss_chg > loss_chg &&
|
||||
other.left_sum.GetHess() >= param.min_child_weight &&
|
||||
other.right_sum.GetHess() >= param.min_child_weight) {
|
||||
@@ -170,8 +191,10 @@ struct SumCallbackOp {
|
||||
};
|
||||
|
||||
template <typename gpair_t>
|
||||
__device__ inline float device_calc_loss_chg(
|
||||
const GPUTrainingParam& param, const gpair_t& left, const gpair_t& parent_sum, const float& parent_gain) {
|
||||
__device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
|
||||
const gpair_t& left,
|
||||
const gpair_t& parent_sum,
|
||||
const float& parent_gain) {
|
||||
gpair_t right = parent_sum - left;
|
||||
float left_gain = CalcGain(param, left.GetGrad(), left.GetHess());
|
||||
float right_gain = CalcGain(param, right.GetGrad(), right.GetHess());
|
||||
@@ -187,8 +210,8 @@ __device__ float inline loss_chg_missing(const gpair_t& scan,
|
||||
bool& missing_left_out) { // NOLINT
|
||||
float missing_left_loss =
|
||||
device_calc_loss_chg(param, scan + missing, parent_sum, parent_gain);
|
||||
float missing_right_loss = device_calc_loss_chg(
|
||||
param, scan, parent_sum, parent_gain);
|
||||
float missing_right_loss =
|
||||
device_calc_loss_chg(param, scan, parent_sum, parent_gain);
|
||||
|
||||
if (missing_left_loss >= missing_right_loss) {
|
||||
missing_left_out = true;
|
||||
|
||||
@@ -537,6 +537,9 @@ class GPUHistMaker : public TreeUpdater {
|
||||
device_gpair[d_idx].copy(gpair.begin() + device_row_segments[d_idx],
|
||||
gpair.begin() + device_row_segments[d_idx + 1]);
|
||||
|
||||
// Check gradients are within acceptable size range
|
||||
CheckGradientMax(device_gpair[d_idx]);
|
||||
|
||||
subsample_gpair(&device_gpair[d_idx], param.subsample,
|
||||
device_row_segments[d_idx]);
|
||||
|
||||
|
||||
@@ -334,6 +334,8 @@ struct DeviceShard {
|
||||
ridx_segments.front() = std::make_pair(0, ridx.size());
|
||||
this->gpair.copy(host_gpair.begin() + row_start_idx,
|
||||
host_gpair.begin() + row_end_idx);
|
||||
// Check gradients are within acceptable size range
|
||||
CheckGradientMax(gpair);
|
||||
hist.Reset();
|
||||
}
|
||||
|
||||
@@ -551,8 +553,8 @@ class GPUHistMakerExperimental : public TreeUpdater {
|
||||
__device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
|
||||
unsigned ballot = __ballot(val == left_nidx);
|
||||
if (threadIdx.x % 32 == 0) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
|
||||
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
|
||||
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user