[CORE] The update process for a tree model, and its application to feature importance (#1670)
* [CORE] allow updating trees in an existing model
* [CORE] in refresh updater, allow keeping old leaf values and update stats only
* [R-package] xgb.train mod to allow updating trees in an existing model
* [R-package] added check for nrounds when is_update
* [CORE] merge parameter declaration changes; unify their code style
* [CORE] move the update-process trees initialization to Configure; rename default process_type to 'default'; fix the trees and trees_to_update sizes comparison check
* [R-package] unit tests for the update process type
* [DOC] documentation for process_type parameter; improved docs for updater, Gamma and Tweedie; added some parameter aliases; metrics indentation and some were non-documented
* fix my sloppy merge conflict resolutions
* [CORE] add a TreeProcessType enum
* whitespace fix
This commit is contained in:
committed by
Tianqi Chen
parent
4398fbbe4a
commit
a44032d095
@@ -26,6 +26,12 @@ namespace gbm {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gbtree);
|
||||
|
||||
// boosting process types
|
||||
/*!
 * \brief Type of boosting process to run.
 *
 * Selected via the `process_type` training parameter: either grow new
 * trees each round (the normal boosting process) or re-run the updater
 * sequence on the trees of an existing model.
 */
enum TreeProcessType {
  kDefault = 0,  // normal boosting: create new trees every iteration
  kUpdate = 1    // update the trees of an existing model in place
};
|
||||
|
||||
/*! \brief training parameters */
|
||||
struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||
/*!
|
||||
@@ -35,13 +41,24 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||
int num_parallel_tree;
|
||||
/*! \brief tree updater sequence */
|
||||
std::string updater_seq;
|
||||
/*! \brief type of boosting process to run */
|
||||
int process_type;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GBTreeTrainParam) {
|
||||
DMLC_DECLARE_FIELD(num_parallel_tree).set_lower_bound(1).set_default(1)
|
||||
DMLC_DECLARE_FIELD(num_parallel_tree)
|
||||
.set_default(1)
|
||||
.set_lower_bound(1)
|
||||
.describe("Number of parallel trees constructed during each iteration."\
|
||||
" This option is used to support boosted random forest");
|
||||
DMLC_DECLARE_FIELD(updater_seq).set_default("grow_colmaker,prune")
|
||||
DMLC_DECLARE_FIELD(updater_seq)
|
||||
.set_default("grow_colmaker,prune")
|
||||
.describe("Tree updater sequence.");
|
||||
DMLC_DECLARE_FIELD(process_type)
|
||||
.set_default(kDefault)
|
||||
.add_enum("default", kDefault)
|
||||
.add_enum("update", kUpdate)
|
||||
.describe("Whether to run the normal boosting process that creates new trees,"\
|
||||
" or to update the trees in an existing model.");
|
||||
// add alias
|
||||
DMLC_DECLARE_ALIAS(updater_seq, updater);
|
||||
}
|
||||
@@ -63,21 +80,30 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
|
||||
float learning_rate;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(DartTrainParam) {
|
||||
DMLC_DECLARE_FIELD(silent).set_default(false)
|
||||
DMLC_DECLARE_FIELD(silent)
|
||||
.set_default(false)
|
||||
.describe("Not print information during training.");
|
||||
DMLC_DECLARE_FIELD(sample_type).set_default(0)
|
||||
DMLC_DECLARE_FIELD(sample_type)
|
||||
.set_default(0)
|
||||
.add_enum("uniform", 0)
|
||||
.add_enum("weighted", 1)
|
||||
.describe("Different types of sampling algorithm.");
|
||||
DMLC_DECLARE_FIELD(normalize_type).set_default(0)
|
||||
DMLC_DECLARE_FIELD(normalize_type)
|
||||
.set_default(0)
|
||||
.add_enum("tree", 0)
|
||||
.add_enum("forest", 1)
|
||||
.describe("Different types of normalization algorithm.");
|
||||
DMLC_DECLARE_FIELD(rate_drop).set_range(0.0f, 1.0f).set_default(0.0f)
|
||||
DMLC_DECLARE_FIELD(rate_drop)
|
||||
.set_range(0.0f, 1.0f)
|
||||
.set_default(0.0f)
|
||||
.describe("Parameter of how many trees are dropped.");
|
||||
DMLC_DECLARE_FIELD(skip_drop).set_range(0.0f, 1.0f).set_default(0.0f)
|
||||
DMLC_DECLARE_FIELD(skip_drop)
|
||||
.set_range(0.0f, 1.0f)
|
||||
.set_default(0.0f)
|
||||
.describe("Parameter of whether to drop trees.");
|
||||
DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
|
||||
DMLC_DECLARE_FIELD(learning_rate)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(0.3f)
|
||||
.describe("Learning rate(step size) of update.");
|
||||
DMLC_DECLARE_ALIAS(learning_rate, eta);
|
||||
}
|
||||
@@ -157,12 +183,21 @@ class GBTree : public GradientBooster {
|
||||
for (const auto& up : updaters) {
|
||||
up->Init(cfg);
|
||||
}
|
||||
// for the 'update' process_type, move trees into trees_to_update
|
||||
if (tparam.process_type == kUpdate && trees_to_update.size() == 0u) {
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees_to_update.push_back(std::move(trees[i]));
|
||||
}
|
||||
trees.clear();
|
||||
mparam.num_trees = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
|
||||
<< "GBTree: invalid model file";
|
||||
trees.clear();
|
||||
trees_to_update.clear();
|
||||
for (int i = 0; i < mparam.num_trees; ++i) {
|
||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||
ptr->Load(fi);
|
||||
@@ -386,11 +421,20 @@ class GBTree : public GradientBooster {
|
||||
ret->clear();
|
||||
// create the trees
|
||||
for (int i = 0; i < tparam.num_parallel_tree; ++i) {
|
||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||
ptr->param.InitAllowUnknown(this->cfg);
|
||||
ptr->InitModel();
|
||||
new_trees.push_back(ptr.get());
|
||||
ret->push_back(std::move(ptr));
|
||||
if (tparam.process_type == kDefault) {
|
||||
// create new tree
|
||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||
ptr->param.InitAllowUnknown(this->cfg);
|
||||
ptr->InitModel();
|
||||
new_trees.push_back(ptr.get());
|
||||
ret->push_back(std::move(ptr));
|
||||
} else if (tparam.process_type == kUpdate) {
|
||||
CHECK_LT(trees.size(), trees_to_update.size());
|
||||
// move an existing tree from trees_to_update
|
||||
auto t = std::move(trees_to_update[trees.size()]);
|
||||
new_trees.push_back(t.get());
|
||||
ret->push_back(std::move(t));
|
||||
}
|
||||
}
|
||||
// update the trees
|
||||
for (auto& up : updaters) {
|
||||
@@ -493,6 +537,8 @@ class GBTree : public GradientBooster {
|
||||
GBTreeModelParam mparam;
|
||||
/*! \brief vector of trees stored in the model */
|
||||
std::vector<std::unique_ptr<RegTree> > trees;
|
||||
/*! \brief for the update process, a place to keep the initial trees */
|
||||
std::vector<std::unique_ptr<RegTree> > trees_to_update;
|
||||
/*! \brief some information indicator of the tree, reserved */
|
||||
std::vector<int> tree_info;
|
||||
// ----training fields----
|
||||
|
||||
@@ -64,6 +64,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
bool cache_opt;
|
||||
// whether to not print info during training.
|
||||
bool silent;
|
||||
// whether refresh updater needs to update the leaf values
|
||||
bool refresh_leaf;
|
||||
// auxiliary data structure
|
||||
std::vector<int> monotone_constraints;
|
||||
// declare the parameters
|
||||
@@ -75,10 +77,11 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
DMLC_DECLARE_FIELD(min_split_loss)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(0.0f)
|
||||
.describe(
|
||||
"Minimum loss reduction required to make a further partition.");
|
||||
DMLC_DECLARE_FIELD(max_depth).set_lower_bound(0).set_default(6).describe(
|
||||
"Maximum depth of the tree.");
|
||||
.describe("Minimum loss reduction required to make a further partition.");
|
||||
DMLC_DECLARE_FIELD(max_depth)
|
||||
.set_lower_bound(0)
|
||||
.set_default(6)
|
||||
.describe("Maximum depth of the tree.");
|
||||
DMLC_DECLARE_FIELD(min_child_weight)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(1.0f)
|
||||
@@ -100,9 +103,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
DMLC_DECLARE_FIELD(max_delta_step)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(0.0f)
|
||||
.describe(
|
||||
"Maximum delta step we allow each tree's weight estimate to be. "
|
||||
"If the value is set to 0, it means there is no constraint");
|
||||
.describe("Maximum delta step we allow each tree's weight estimate to be. "\
|
||||
"If the value is set to 0, it means there is no constraint");
|
||||
DMLC_DECLARE_FIELD(subsample)
|
||||
.set_range(0.0f, 1.0f)
|
||||
.set_default(1.0f)
|
||||
@@ -114,8 +116,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
DMLC_DECLARE_FIELD(colsample_bytree)
|
||||
.set_range(0.0f, 1.0f)
|
||||
.set_default(1.0f)
|
||||
.describe(
|
||||
"Subsample ratio of columns, resample on each tree construction.");
|
||||
.describe("Subsample ratio of columns, resample on each tree construction.");
|
||||
DMLC_DECLARE_FIELD(opt_dense_col)
|
||||
.set_range(0.0f, 1.0f)
|
||||
.set_default(1.0f)
|
||||
@@ -127,8 +128,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
DMLC_DECLARE_FIELD(sketch_ratio)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(2.0f)
|
||||
.describe("EXP Param: Sketch accuracy related parameter of approximate "
|
||||
"algorithm.");
|
||||
.describe("EXP Param: Sketch accuracy related parameter of approximate algorithm.");
|
||||
DMLC_DECLARE_FIELD(size_leaf_vector)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
@@ -136,10 +136,15 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
DMLC_DECLARE_FIELD(parallel_option)
|
||||
.set_default(0)
|
||||
.describe("Different types of parallelization algorithm.");
|
||||
DMLC_DECLARE_FIELD(cache_opt).set_default(true).describe(
|
||||
"EXP Param: Cache aware optimization.");
|
||||
DMLC_DECLARE_FIELD(silent).set_default(false).describe(
|
||||
"Do not print information during training.");
|
||||
DMLC_DECLARE_FIELD(cache_opt)
|
||||
.set_default(true)
|
||||
.describe("EXP Param: Cache aware optimization.");
|
||||
DMLC_DECLARE_FIELD(silent)
|
||||
.set_default(false)
|
||||
.describe("Do not print information during trainig.");
|
||||
DMLC_DECLARE_FIELD(refresh_leaf)
|
||||
.set_default(true)
|
||||
.describe("Whether the refresh updater needs to update leaf values.");
|
||||
DMLC_DECLARE_FIELD(monotone_constraints)
|
||||
.set_default(std::vector<int>())
|
||||
.describe("Constraint of variable monotonicity");
|
||||
|
||||
@@ -134,7 +134,9 @@ class TreeRefresher: public TreeUpdater {
|
||||
tree.stat(nid).sum_hess = static_cast<bst_float>(gstats[nid].sum_hess);
|
||||
gstats[nid].SetLeafVec(param, tree.leafvec(nid));
|
||||
if (tree[nid].is_leaf()) {
|
||||
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
||||
if (param.refresh_leaf) {
|
||||
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
||||
}
|
||||
} else {
|
||||
tree.stat(nid).loss_chg = static_cast<bst_float>(
|
||||
gstats[tree[nid].cleft()].CalcGain(param) +
|
||||
|
||||
Reference in New Issue
Block a user