add max_delta_step

parent 149b43a0a8
commit 23e46b7fa5
@@ -27,7 +27,11 @@ struct TrainParam{
   // L1 regularization factor
   float reg_alpha;
   // default direction choice
   int default_direction;
+  // maximum delta update we can add in weight estimation
+  // this parameter can be used to stabilize the update
+  // default=0 means no constraint on the weight delta
+  float max_delta_step;
   // whether we want to do subsample
   float subsample;
   // whether to subsample columns each split, in each level
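Note: per the new comments, the default of 0 switches the constraint off, so existing setups are unaffected, while a positive value caps how far a single leaf-weight update may move. (For context, later XGBoost documentation suggests this cap mainly helps logistic regression on extremely imbalanced data, where near-zero hessians can make unconstrained updates blow up; that guidance is not part of this commit.)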
@@ -52,6 +56,7 @@ struct TrainParam{
     learning_rate = 0.3f;
     min_split_loss = 0.0f;
     min_child_weight = 1.0f;
+    max_delta_step = 0.0f;
     max_depth = 6;
     reg_lambda = 1.0f;
     reg_alpha = 0.0f;
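The zero default in the initializer above is what keeps the patch backward compatible: with max_delta_step = 0.0f, both CalcGain and CalcWeight below fall through to the pre-patch code paths.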
@@ -81,6 +86,7 @@ struct TrainParam{
     if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
     if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
     if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
+    if (!strcmp(name, "max_delta_step")) max_delta_step = static_cast<float>(atof(val));
     if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
     if (!strcmp(name, "reg_alpha")) reg_alpha = static_cast<float>(atof(val));
     if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
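For illustration, a trimmed stand-in (not the real TrainParam, whose full definition is outside this diff) showing how a caller would drive the string-based SetParam interface above:

#include <cstdio>
#include <cstdlib>
#include <cstring>

// Trimmed stand-in for TrainParam: only the field this patch adds,
// parsed with the same strcmp/atof pattern as the hunk above.
struct MiniParam {
  float max_delta_step;
  MiniParam() : max_delta_step(0.0f) {}
  inline void SetParam(const char *name, const char *val) {
    if (!strcmp(name, "max_delta_step")) max_delta_step = static_cast<float>(atof(val));
  }
};

int main() {
  MiniParam p;
  p.SetParam("max_delta_step", "0.7");  // e.g. one key=value pair from a config file
  std::printf("max_delta_step=%g\n", p.max_delta_step);  // prints 0.7
  return 0;
}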
@@ -102,10 +108,20 @@ struct TrainParam{
   // calculate the cost of loss function
   inline double CalcGain(double sum_grad, double sum_hess) const {
     if (sum_hess < min_child_weight) return 0.0;
-    if (reg_alpha == 0.0f) {
-      return Sqr(sum_grad) / (sum_hess + reg_lambda);
+    if (max_delta_step == 0.0f) {
+      if (reg_alpha == 0.0f) {
+        return Sqr(sum_grad) / (sum_hess + reg_lambda);
+      } else {
+        return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
+      }
     } else {
-      return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
+      double w = CalcWeight(sum_grad, sum_hess);
+      double ret = sum_grad * w + 0.5 * (sum_hess + reg_lambda) * Sqr(w);
+      if (reg_alpha == 0.0f) {
+        return -2.0 * ret;
+      } else {
+        return -2.0 * (ret + reg_alpha * std::abs(w));
+      }
     }
   }
   // calculate cost of loss function with four statistics
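Why the new CalcGain branch evaluates the gain this way (a reconstruction from the code above; I am assuming Sqr(x) = x*x and that CalcWeight returns the possibly clamped leaf weight, as in the next hunk). The per-leaf objective being minimized is, in LaTeX notation,

    \mathrm{obj}(w) = G\,w + \tfrac{1}{2}(H + \lambda)\,w^{2} + \alpha\,\lvert w \rvert

with G = sum_grad, H = sum_hess, \lambda = reg_lambda, \alpha = reg_alpha. Unconstrained and with \alpha = 0, the minimizer is w^{*} = -G/(H + \lambda), so \mathrm{obj}(w^{*}) = -\tfrac{1}{2}\,G^{2}/(H + \lambda) and the gain -2\,\mathrm{obj}(w^{*}) = G^{2}/(H + \lambda), exactly the closed form kept in the max_delta_step == 0.0f branch (ThresholdL1 handles the \alpha \neq 0 case). Once the weight may be clamped, that closed form no longer applies, so the else branch evaluates -2\,\mathrm{obj}(w) at the actual weight: ret is G w + \tfrac{1}{2}(H + \lambda) w^{2}, and reg_alpha * std::abs(w) restores the L1 term.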
@@ -122,11 +138,17 @@ struct TrainParam{
   // calculate weight given the statistics
   inline double CalcWeight(double sum_grad, double sum_hess) const {
     if (sum_hess < min_child_weight) return 0.0;
+    double dw;
     if (reg_alpha == 0.0f) {
-      return -sum_grad / (sum_hess + reg_lambda);
+      dw = -sum_grad / (sum_hess + reg_lambda);
     } else {
-      return -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
+      dw = -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
     }
+    if (max_delta_step != 0.0f) {
+      if (dw > max_delta_step) dw = max_delta_step;
+      if (dw < -max_delta_step) dw = -max_delta_step;
+    }
+    return dw;
   }
   /*! \brief whether need forward small to big search: default right */
   inline bool need_forward_search(float col_density = 0.0f) const {
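As a sanity check on the hunk above, a self-contained sketch of the patched CalcWeight; Sqr and ThresholdL1 are my own stand-ins, under the assumption that ThresholdL1 is the usual soft-threshold, which matches how it is used in this file:

#include <cstdio>

// Assumed semantics of the helpers referenced by the diff:
// ThresholdL1 soft-thresholds the gradient sum by alpha.
inline double ThresholdL1(double w, double alpha) {
  if (w > +alpha) return w - alpha;
  if (w < -alpha) return w + alpha;
  return 0.0;
}

// Mirror of the patched CalcWeight: compute the raw leaf weight, then
// clamp it to [-max_delta_step, max_delta_step] when the cap is nonzero.
double CalcWeight(double sum_grad, double sum_hess,
                  double reg_lambda, double reg_alpha,
                  double max_delta_step, double min_child_weight) {
  if (sum_hess < min_child_weight) return 0.0;
  double dw = (reg_alpha == 0.0)
      ? -sum_grad / (sum_hess + reg_lambda)
      : -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
  if (max_delta_step != 0.0) {
    if (dw > max_delta_step) dw = max_delta_step;
    if (dw < -max_delta_step) dw = -max_delta_step;
  }
  return dw;
}

int main() {
  // Unconstrained weight is -(-8)/(1+1) = 4.0; a cap of 1.0 trims it.
  std::printf("no cap: %g\n", CalcWeight(-8.0, 1.0, 1.0, 0.0, 0.0, 0.0));  // 4
  std::printf("cap=1 : %g\n", CalcWeight(-8.0, 1.0, 1.0, 0.0, 1.0, 0.0));  // 1
  return 0;
}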