diff --git a/src/tree/param.h b/src/tree/param.h
index 41abba90f..3458a93a4 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -27,7 +27,11 @@ struct TrainParam{
   // L1 regularization factor
   float reg_alpha;
   // default direction choice
-  int default_direction;
+  int default_direction;
+  // maximum delta update we can add in weight estimation
+  // this parameter can be used to stabilize update
+  // default=0 means no constraint on weight delta
+  float max_delta_step;
   // whether we want to do subsample
   float subsample;
   // whether to subsample columns each split, in each level
@@ -52,6 +56,7 @@ struct TrainParam{
     learning_rate = 0.3f;
     min_split_loss = 0.0f;
     min_child_weight = 1.0f;
+    max_delta_step = 0.0f;
     max_depth = 6;
     reg_lambda = 1.0f;
     reg_alpha = 0.0f;
@@ -81,6 +86,7 @@ struct TrainParam{
     if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
     if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
     if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
+    if (!strcmp(name, "max_delta_step")) max_delta_step = static_cast<float>(atof(val));
     if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
     if (!strcmp(name, "reg_alpha")) reg_alpha = static_cast<float>(atof(val));
     if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
@@ -102,10 +108,20 @@ struct TrainParam{
   // calculate the cost of loss function
   inline double CalcGain(double sum_grad, double sum_hess) const {
     if (sum_hess < min_child_weight) return 0.0;
-    if (reg_alpha == 0.0f) {
-      return Sqr(sum_grad) / (sum_hess + reg_lambda);
+    if (max_delta_step == 0.0f) {
+      if (reg_alpha == 0.0f) {
+        return Sqr(sum_grad) / (sum_hess + reg_lambda);
+      } else {
+        return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
+      }
     } else {
-      return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
+      double w = CalcWeight(sum_grad, sum_hess);
+      double ret = sum_grad * w + 0.5 * (sum_hess + reg_lambda) * Sqr(w);
+      if (reg_alpha == 0.0f) {
+        return - 2.0 * ret;
+      } else {
+        return - 2.0 * (ret + reg_alpha * std::abs(w));
+      }
     }
   }
   // calculate cost of loss function with four statistics
@@ -122,11 +138,17 @@ struct TrainParam{
   // calculate weight given the statistics
   inline double CalcWeight(double sum_grad, double sum_hess) const {
     if (sum_hess < min_child_weight) return 0.0;
+    double dw;
     if (reg_alpha == 0.0f) {
-      return -sum_grad / (sum_hess + reg_lambda);
+      dw = -sum_grad / (sum_hess + reg_lambda);
     } else {
-      return -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
+      dw = -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
     }
+    if (max_delta_step != 0.0f) {
+      if (dw > max_delta_step) dw = max_delta_step;
+      if (dw < -max_delta_step) dw = -max_delta_step;
+    }
+    return dw;
   }
   /*! \brief whether need forward small to big search: default right */
   inline bool need_forward_search(float col_density = 0.0f) const {