Support slicing tree model (#6302)

This PR is meant the end the confusion around best_ntree_limit and unify model slicing. We have multi-class and random forests, asking users to understand how to set ntree_limit is difficult and error prone.

* Implement the save_best option in early stopping.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-11-03 02:27:39 -05:00
committed by GitHub
parent 29745c6df2
commit 2cc9662005
19 changed files with 550 additions and 37 deletions

View File

@@ -580,6 +580,23 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
*/
XGB_DLL int XGBoosterFree(BoosterHandle handle);
/*!
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
*
* \param handle Booster to be sliced.
* \param begin_layer start of the slice
* \param end_layer end of the slice; end_layer=0 is equivalent to
* end_layer=num_boost_round
* \param step step size of the slice
* \param out Sliced booster.
*
* \return 0 when success, -1 when failure happens, -2 when index is out of bound.
*/
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
BoosterHandle *out);
/*!
* \brief set parameters
* \param handle handle

View File

@@ -60,6 +60,17 @@ class GradientBooster : public Model, public Configurable {
* \param fo output stream
*/
virtual void Save(dmlc::Stream* fo) const = 0;
/*!
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
* \param layer_begin Begining of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \param out Output gradient booster
*/
virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const {
LOG(FATAL) << "Slice is not supported by current booster.";
}
/*!
* \brief whether the model allow lazy checkpoint
* return true if model is only updated in DoBoost

View File

@@ -195,6 +195,18 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \return whether the model allow lazy checkpoint in rabit.
*/
bool AllowLazyCheckPoint() const;
/*!
* \brief Slice the model.
*
* See InplacePredict for layer parameters.
*
* \param step step size between slice.
* \param out_of_bound Return true if end layer is out of bound.
*
* \return a sliced model.
*/
virtual Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
bool *out_of_bound) = 0;
/*!
* \brief dump the model in the requested format
* \param fmap feature map that may help give interpretations of feature