add max_delta_step

This commit is contained in:
tqchen 2015-03-26 09:47:16 -07:00
parent 149b43a0a8
commit 23e46b7fa5

View File

@ -28,6 +28,10 @@ struct TrainParam{
float reg_alpha;
// default direction choice
int default_direction;
// maximum delta update we can add in weight estimation
// this parameter can be used to stablize update
// default=0 means no constraint on weight delta
float max_delta_step;
// whether we want to do subsample
float subsample;
// whether to subsample columns each split, in each level
@ -52,6 +56,7 @@ struct TrainParam{
learning_rate = 0.3f;
min_split_loss = 0.0f;
min_child_weight = 1.0f;
max_delta_step = 0.0f;
max_depth = 6;
reg_lambda = 1.0f;
reg_alpha = 0.0f;
@ -81,6 +86,7 @@ struct TrainParam{
if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
if (!strcmp(name, "max_delta_step")) max_delta_step = static_cast<float>(atof(val));
if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
if (!strcmp(name, "reg_alpha")) reg_alpha = static_cast<float>(atof(val));
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
@ -102,11 +108,21 @@ struct TrainParam{
// calculate the cost of loss function
inline double CalcGain(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) return 0.0;
if (max_delta_step == 0.0f) {
if (reg_alpha == 0.0f) {
return Sqr(sum_grad) / (sum_hess + reg_lambda);
} else {
return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
}
} else {
double w = CalcWeight(sum_grad, sum_hess);
double ret = sum_grad * w + 0.5 * (sum_hess + reg_lambda) * Sqr(w);
if (reg_alpha == 0.0f) {
return - 2.0 * ret;
} else {
return - 2.0 * (ret + reg_alpha * std::abs(w));
}
}
}
// calculate cost of loss function with four stati
inline double CalcGain(double sum_grad, double sum_hess,
@ -122,11 +138,17 @@ struct TrainParam{
// calculate weight given the statistics
inline double CalcWeight(double sum_grad, double sum_hess) const {
if (sum_hess < min_child_weight) return 0.0;
double dw;
if (reg_alpha == 0.0f) {
return -sum_grad / (sum_hess + reg_lambda);
dw = -sum_grad / (sum_hess + reg_lambda);
} else {
return -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
dw = -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
}
if (max_delta_step != 0.0f) {
if (dw > max_delta_step) dw = max_delta_step;
if (dw < -max_delta_step) dw = -max_delta_step;
}
return dw;
}
/*! \brief whether need forward small to big search: default right */
inline bool need_forward_search(float col_density = 0.0f) const {