Support slicing tree model (#6302)
This PR is meant the end the confusion around best_ntree_limit and unify model slicing. We have multi-class and random forests, asking users to understand how to set ntree_limit is difficult and error prone. * Implement the save_best option in early stopping. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -580,6 +580,23 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
||||
*/
|
||||
XGB_DLL int XGBoosterFree(BoosterHandle handle);
|
||||
|
||||
/*!
|
||||
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
|
||||
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
|
||||
*
|
||||
* \param handle Booster to be sliced.
|
||||
* \param begin_layer start of the slice
|
||||
* \param end_layer end of the slice; end_layer=0 is equivalent to
|
||||
* end_layer=num_boost_round
|
||||
* \param step step size of the slice
|
||||
* \param out Sliced booster.
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens, -2 when index is out of bound.
|
||||
*/
|
||||
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
|
||||
int end_layer, int step,
|
||||
BoosterHandle *out);
|
||||
|
||||
/*!
|
||||
* \brief set parameters
|
||||
* \param handle handle
|
||||
|
||||
@@ -60,6 +60,17 @@ class GradientBooster : public Model, public Configurable {
|
||||
* \param fo output stream
|
||||
*/
|
||||
virtual void Save(dmlc::Stream* fo) const = 0;
|
||||
/*!
|
||||
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
|
||||
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
|
||||
* \param layer_begin Begining of boosted tree layer used for prediction.
|
||||
* \param layer_end End of booster layer. 0 means do not limit trees.
|
||||
* \param out Output gradient booster
|
||||
*/
|
||||
virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
|
||||
GradientBooster *out, bool* out_of_bound) const {
|
||||
LOG(FATAL) << "Slice is not supported by current booster.";
|
||||
}
|
||||
/*!
|
||||
* \brief whether the model allow lazy checkpoint
|
||||
* return true if model is only updated in DoBoost
|
||||
|
||||
@@ -195,6 +195,18 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
||||
* \return whether the model allow lazy checkpoint in rabit.
|
||||
*/
|
||||
bool AllowLazyCheckPoint() const;
|
||||
/*!
|
||||
* \brief Slice the model.
|
||||
*
|
||||
* See InplacePredict for layer parameters.
|
||||
*
|
||||
* \param step step size between slice.
|
||||
* \param out_of_bound Return true if end layer is out of bound.
|
||||
*
|
||||
* \return a sliced model.
|
||||
*/
|
||||
virtual Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
|
||||
bool *out_of_bound) = 0;
|
||||
/*!
|
||||
* \brief dump the model in the requested format
|
||||
* \param fmap feature map that may help give interpretations of feature
|
||||
|
||||
Reference in New Issue
Block a user