add max_delta_step
This commit is contained in:
parent
149b43a0a8
commit
23e46b7fa5
@ -27,7 +27,11 @@ struct TrainParam{
|
||||
// L1 regularization factor
|
||||
float reg_alpha;
|
||||
// default direction choice
|
||||
int default_direction;
|
||||
int default_direction;
|
||||
// maximum delta update we can add in weight estimation
|
||||
// this parameter can be used to stablize update
|
||||
// default=0 means no constraint on weight delta
|
||||
float max_delta_step;
|
||||
// whether we want to do subsample
|
||||
float subsample;
|
||||
// whether to subsample columns each split, in each level
|
||||
@ -52,6 +56,7 @@ struct TrainParam{
|
||||
learning_rate = 0.3f;
|
||||
min_split_loss = 0.0f;
|
||||
min_child_weight = 1.0f;
|
||||
max_delta_step = 0.0f;
|
||||
max_depth = 6;
|
||||
reg_lambda = 1.0f;
|
||||
reg_alpha = 0.0f;
|
||||
@ -81,6 +86,7 @@ struct TrainParam{
|
||||
if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "max_delta_step")) max_delta_step = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "reg_alpha")) reg_alpha = static_cast<float>(atof(val));
|
||||
if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
|
||||
@ -102,10 +108,20 @@ struct TrainParam{
|
||||
// calculate the cost of loss function
|
||||
inline double CalcGain(double sum_grad, double sum_hess) const {
|
||||
if (sum_hess < min_child_weight) return 0.0;
|
||||
if (reg_alpha == 0.0f) {
|
||||
return Sqr(sum_grad) / (sum_hess + reg_lambda);
|
||||
if (max_delta_step == 0.0f) {
|
||||
if (reg_alpha == 0.0f) {
|
||||
return Sqr(sum_grad) / (sum_hess + reg_lambda);
|
||||
} else {
|
||||
return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
|
||||
}
|
||||
} else {
|
||||
return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
|
||||
double w = CalcWeight(sum_grad, sum_hess);
|
||||
double ret = sum_grad * w + 0.5 * (sum_hess + reg_lambda) * Sqr(w);
|
||||
if (reg_alpha == 0.0f) {
|
||||
return - 2.0 * ret;
|
||||
} else {
|
||||
return - 2.0 * (ret + reg_alpha * std::abs(w));
|
||||
}
|
||||
}
|
||||
}
|
||||
// calculate cost of loss function with four stati
|
||||
@ -122,11 +138,17 @@ struct TrainParam{
|
||||
// calculate weight given the statistics
|
||||
inline double CalcWeight(double sum_grad, double sum_hess) const {
|
||||
if (sum_hess < min_child_weight) return 0.0;
|
||||
double dw;
|
||||
if (reg_alpha == 0.0f) {
|
||||
return -sum_grad / (sum_hess + reg_lambda);
|
||||
dw = -sum_grad / (sum_hess + reg_lambda);
|
||||
} else {
|
||||
return -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
|
||||
dw = -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
|
||||
}
|
||||
if (max_delta_step != 0.0f) {
|
||||
if (dw > max_delta_step) dw = max_delta_step;
|
||||
if (dw < -max_delta_step) dw = -max_delta_step;
|
||||
}
|
||||
return dw;
|
||||
}
|
||||
/*! \brief whether need forward small to big search: default right */
|
||||
inline bool need_forward_search(float col_density = 0.0f) const {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user