Support slicing tree model (#6302)
This PR is meant the end the confusion around best_ntree_limit and unify model slicing. We have multi-class and random forests, asking users to understand how to set ntree_limit is difficult and error prone. * Implement the save_best option in early stopping. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
29745c6df2
commit
2cc9662005
@ -7,9 +7,9 @@ package. In XGBoost 1.3, a new callback interface is designed for Python packag
|
||||
provides the flexiblity of designing various extension for training. Also, XGBoost has a
|
||||
number of pre-defined callbacks for supporting early stopping, checkpoints etc.
|
||||
|
||||
#######################
|
||||
|
||||
Using builtin callbacks
|
||||
#######################
|
||||
-----------------------
|
||||
|
||||
By default, training methods in XGBoost have parameters like ``early_stopping_rounds`` and
|
||||
``verbose``/``verbose_eval``, when specified the training procedure will define the
|
||||
@ -50,9 +50,9 @@ this callback function directly into XGBoost:
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)
|
||||
|
||||
##########################
|
||||
|
||||
Defining your own callback
|
||||
##########################
|
||||
--------------------------
|
||||
|
||||
XGBoost provides an callback interface class: ``xgboost.callback.TrainingCallback``, user
|
||||
defined callbacks should inherit this class and override corresponding methods. There's a
|
||||
|
||||
@ -12,4 +12,5 @@ Contents
|
||||
python_intro
|
||||
python_api
|
||||
callbacks
|
||||
model
|
||||
Python examples <https://github.com/dmlc/xgboost/tree/master/demo/guide-python>
|
||||
|
||||
38
doc/python/model.rst
Normal file
38
doc/python/model.rst
Normal file
@ -0,0 +1,38 @@
|
||||
#####
|
||||
Model
|
||||
#####
|
||||
|
||||
Slice tree model
|
||||
----------------
|
||||
|
||||
When ``booster`` is set to ``gbtree`` or ``dart``, XGBoost builds a tree model, which is a
|
||||
list of trees and can be sliced into multiple sub-models.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from sklearn.datasets import make_classification
|
||||
num_classes = 3
|
||||
X, y = make_classification(n_samples=1000, n_informative=5,
|
||||
n_classes=num_classes)
|
||||
dtrain = xgb.DMatrix(data=X, label=y)
|
||||
num_parallel_tree = 4
|
||||
num_boost_round = 16
|
||||
# total number of built trees is num_parallel_tree * num_classes * num_boost_round
|
||||
|
||||
# We build a boosted random forest for classification here.
|
||||
booster = xgb.train({
|
||||
'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3},
|
||||
num_boost_round=num_boost_round, dtrain=dtrain)
|
||||
|
||||
# This is the sliced model, containing [3, 7) forests
|
||||
# step is also supported with some limitations like negative step is invalid.
|
||||
sliced: xgb.Booster = booster[3:7]
|
||||
|
||||
# Access individual tree layer
|
||||
trees = [_ for _ in booster]
|
||||
assert len(trees) == num_boost_round
|
||||
|
||||
|
||||
The sliced model is a copy of selected trees, that means the model itself is immutable
|
||||
during slicing. This feature is the basis of `save_best` option in early stopping
|
||||
callback.
|
||||
@ -580,6 +580,23 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
||||
*/
|
||||
XGB_DLL int XGBoosterFree(BoosterHandle handle);
|
||||
|
||||
/*!
|
||||
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
|
||||
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
|
||||
*
|
||||
* \param handle Booster to be sliced.
|
||||
* \param begin_layer start of the slice
|
||||
* \param end_layer end of the slice; end_layer=0 is equivalent to
|
||||
* end_layer=num_boost_round
|
||||
* \param step step size of the slice
|
||||
* \param out Sliced booster.
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens, -2 when index is out of bound.
|
||||
*/
|
||||
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
|
||||
int end_layer, int step,
|
||||
BoosterHandle *out);
|
||||
|
||||
/*!
|
||||
* \brief set parameters
|
||||
* \param handle handle
|
||||
|
||||
@ -60,6 +60,17 @@ class GradientBooster : public Model, public Configurable {
|
||||
* \param fo output stream
|
||||
*/
|
||||
virtual void Save(dmlc::Stream* fo) const = 0;
|
||||
/*!
|
||||
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
|
||||
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
|
||||
* \param layer_begin Begining of boosted tree layer used for prediction.
|
||||
* \param layer_end End of booster layer. 0 means do not limit trees.
|
||||
* \param out Output gradient booster
|
||||
*/
|
||||
virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
|
||||
GradientBooster *out, bool* out_of_bound) const {
|
||||
LOG(FATAL) << "Slice is not supported by current booster.";
|
||||
}
|
||||
/*!
|
||||
* \brief whether the model allow lazy checkpoint
|
||||
* return true if model is only updated in DoBoost
|
||||
|
||||
@ -195,6 +195,18 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
||||
* \return whether the model allow lazy checkpoint in rabit.
|
||||
*/
|
||||
bool AllowLazyCheckPoint() const;
|
||||
/*!
|
||||
* \brief Slice the model.
|
||||
*
|
||||
* See InplacePredict for layer parameters.
|
||||
*
|
||||
* \param step step size between slice.
|
||||
* \param out_of_bound Return true if end layer is out of bound.
|
||||
*
|
||||
* \return a sliced model.
|
||||
*/
|
||||
virtual Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
|
||||
bool *out_of_bound) = 0;
|
||||
/*!
|
||||
* \brief dump the model in the requested format
|
||||
* \param fmap feature map that may help give interpretations of feature
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import Callable, List
|
||||
import numpy
|
||||
|
||||
from . import rabit
|
||||
from .core import EarlyStopException, CallbackEnv
|
||||
from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
|
||||
from .compat import STRING_TYPES
|
||||
|
||||
|
||||
@ -279,9 +279,11 @@ class TrainingCallback(ABC):
|
||||
|
||||
def before_training(self, model):
|
||||
'''Run before training starts.'''
|
||||
return model
|
||||
|
||||
def after_training(self, model):
|
||||
'''Run after training is finished.'''
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch, evals_log):
|
||||
'''Run before each iteration. Return True when training should stop.'''
|
||||
@ -362,12 +364,24 @@ class CallbackContainer:
|
||||
def before_training(self, model):
|
||||
'''Function called before training.'''
|
||||
for c in self.callbacks:
|
||||
c.before_training(model=model)
|
||||
model = c.before_training(model=model)
|
||||
msg = 'before_training should return the model'
|
||||
if self.is_cv:
|
||||
assert isinstance(model.cvfolds, list), msg
|
||||
else:
|
||||
assert isinstance(model, Booster), msg
|
||||
return model
|
||||
|
||||
def after_training(self, model):
|
||||
'''Function called after training.'''
|
||||
for c in self.callbacks:
|
||||
c.after_training(model)
|
||||
model = c.after_training(model=model)
|
||||
msg = 'after_training should return the model'
|
||||
if self.is_cv:
|
||||
assert isinstance(model.cvfolds, list), msg
|
||||
else:
|
||||
assert isinstance(model, Booster), msg
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch, dtrain, evals):
|
||||
'''Function called before training iteration.'''
|
||||
@ -461,7 +475,7 @@ class EarlyStopping(TrainingCallback):
|
||||
maximize : bool
|
||||
Whether to maximize evaluation metric. None means auto (discouraged).
|
||||
save_best : bool
|
||||
Placeholder, the feature is not yet supported.
|
||||
Whether training should return the best model or the last model.
|
||||
'''
|
||||
def __init__(self,
|
||||
rounds,
|
||||
@ -473,9 +487,6 @@ class EarlyStopping(TrainingCallback):
|
||||
self.metric_name = metric_name
|
||||
self.rounds = rounds
|
||||
self.save_best = save_best
|
||||
# https://github.com/dmlc/xgboost/issues/5531
|
||||
assert self.save_best is False, 'save best is not yet supported.'
|
||||
|
||||
self.maximize = maximize
|
||||
self.stopping_history = {}
|
||||
|
||||
@ -525,7 +536,7 @@ class EarlyStopping(TrainingCallback):
|
||||
return True
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
def after_iteration(self, model: Booster, epoch, evals_log):
|
||||
msg = 'Must have at least 1 validation dataset for early stopping.'
|
||||
assert len(evals_log.keys()) >= 1, msg
|
||||
data_name = ''
|
||||
@ -551,6 +562,14 @@ class EarlyStopping(TrainingCallback):
|
||||
score = data_log[metric_name][-1]
|
||||
return self._update_rounds(score, data_name, metric_name, model, epoch)
|
||||
|
||||
def after_training(self, model: Booster):
|
||||
try:
|
||||
if self.save_best:
|
||||
model = model[: int(model.attr('best_iteration'))]
|
||||
except XGBoostError as e:
|
||||
raise XGBoostError('`save_best` is not applicable to current booster') from e
|
||||
return model
|
||||
|
||||
|
||||
class EvaluationMonitor(TrainingCallback):
|
||||
'''Print the evaluation result at each iteration.
|
||||
@ -684,9 +703,11 @@ class LegacyCallbacks:
|
||||
|
||||
def before_training(self, model):
|
||||
'''Nothing to do for legacy callbacks'''
|
||||
return model
|
||||
|
||||
def after_training(self, model):
|
||||
'''Nothing to do for legacy callbacks'''
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch, dtrain, evals):
|
||||
'''Called before each iteration.'''
|
||||
|
||||
@ -944,8 +944,8 @@ class Booster(object):
|
||||
Parameters for boosters.
|
||||
cache : list
|
||||
List of cache items.
|
||||
model_file : string or os.PathLike
|
||||
Path to the model file.
|
||||
model_file : string/os.PathLike/Booster/bytearray
|
||||
Path to the model file if it's string or PathLike.
|
||||
"""
|
||||
for d in cache:
|
||||
if not isinstance(d, DMatrix):
|
||||
@ -1021,6 +1021,43 @@ class Booster(object):
|
||||
state['handle'] = handle
|
||||
self.__dict__.update(state)
|
||||
|
||||
def __getitem__(self, val):
|
||||
if isinstance(val, int):
|
||||
val = slice(val, val+1)
|
||||
if isinstance(val, tuple):
|
||||
raise ValueError('Only supports slicing through 1 dimension.')
|
||||
if not isinstance(val, slice):
|
||||
msg = _expect((int, slice), type(val))
|
||||
raise TypeError(msg)
|
||||
if isinstance(val.start, type(Ellipsis)) or val.start is None:
|
||||
start = 0
|
||||
else:
|
||||
start = val.start
|
||||
if isinstance(val.stop, type(Ellipsis)) or val.stop is None:
|
||||
stop = 0
|
||||
else:
|
||||
stop = val.stop
|
||||
if stop < start:
|
||||
raise ValueError('Invalid slice', val)
|
||||
|
||||
step = val.step if val.step is not None else 1
|
||||
|
||||
start = ctypes.c_int(start)
|
||||
stop = ctypes.c_int(stop)
|
||||
step = ctypes.c_int(step)
|
||||
|
||||
sliced_handle = ctypes.c_void_p()
|
||||
status = _LIB.XGBoosterSlice(self.handle, start, stop, step,
|
||||
ctypes.byref(sliced_handle))
|
||||
if status == -2:
|
||||
raise IndexError('Layer index out of range')
|
||||
_check_call(status)
|
||||
|
||||
sliced = Booster()
|
||||
_check_call(_LIB.XGBoosterFree(sliced.handle))
|
||||
sliced.handle = sliced_handle
|
||||
return sliced
|
||||
|
||||
def save_config(self):
|
||||
'''Output internal parameter configuration of Booster as a JSON
|
||||
string.
|
||||
|
||||
@ -103,7 +103,7 @@ def _train_internal(params, dtrain,
|
||||
num_boost_round, feval, evals_result, callbacks,
|
||||
show_stdv=False, cvfolds=None)
|
||||
|
||||
callbacks.before_training(bst)
|
||||
bst = callbacks.before_training(bst)
|
||||
for i in range(start_iteration, num_boost_round):
|
||||
if callbacks.before_iteration(bst, i, dtrain, evals):
|
||||
break
|
||||
@ -125,7 +125,7 @@ def _train_internal(params, dtrain,
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
callbacks.after_training(bst)
|
||||
bst = callbacks.after_training(bst)
|
||||
|
||||
if evals_result is not None and is_new_callback:
|
||||
evals_result.update(callbacks.history)
|
||||
@ -495,9 +495,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
verbose_eval, early_stopping_rounds, maximize, 0,
|
||||
num_boost_round, feval, None, callbacks,
|
||||
show_stdv=show_stdv, cvfolds=cvfolds)
|
||||
callbacks.before_training(cvfolds)
|
||||
|
||||
booster = _PackedBooster(cvfolds)
|
||||
callbacks.before_training(booster)
|
||||
|
||||
for i in range(num_boost_round):
|
||||
if callbacks.before_iteration(booster, i, dtrain, None):
|
||||
@ -524,4 +523,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
results = pd.DataFrame.from_dict(results)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
callbacks.after_training(booster)
|
||||
|
||||
return results
|
||||
|
||||
@ -730,6 +730,22 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
|
||||
int end_layer, int step,
|
||||
BoosterHandle *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* learner = static_cast<Learner*>(handle);
|
||||
bool out_of_bound = false;
|
||||
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
|
||||
if (out_of_bound) {
|
||||
return -2;
|
||||
}
|
||||
CHECK(p_out);
|
||||
*out = p_out;
|
||||
API_END();
|
||||
}
|
||||
|
||||
inline void XGBoostDumpModelImpl(BoosterHandle handle, const FeatureMap &fmap,
|
||||
int with_stats, const char *format,
|
||||
xgboost::bst_ulong *len,
|
||||
|
||||
@ -398,6 +398,38 @@ void GBTree::SaveModel(Json* p_out) const {
|
||||
model_.SaveModel(&model);
|
||||
}
|
||||
|
||||
void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
|
||||
GradientBooster *out, bool* out_of_bound) const {
|
||||
CHECK(configured_);
|
||||
CHECK(out);
|
||||
|
||||
auto p_gbtree = dynamic_cast<GBTree *>(out);
|
||||
CHECK(p_gbtree);
|
||||
GBTreeModel &out_model = p_gbtree->model_;
|
||||
auto layer_trees = this->LayerTrees();
|
||||
|
||||
layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end;
|
||||
CHECK_GE(layer_end, layer_begin);
|
||||
CHECK_GE(step, 1);
|
||||
int32_t n_layers = (layer_end - layer_begin) / step;
|
||||
std::vector<std::unique_ptr<RegTree>> &out_trees = out_model.trees;
|
||||
out_trees.resize(layer_trees * n_layers);
|
||||
std::vector<int32_t> &out_trees_info = out_model.tree_info;
|
||||
out_trees_info.resize(layer_trees * n_layers);
|
||||
out_model.param.num_trees = out_model.trees.size();
|
||||
CHECK(this->model_.trees_to_update.empty());
|
||||
|
||||
*out_of_bound = detail::SliceTrees(
|
||||
layer_begin, layer_end, step, this->model_, tparam_, layer_trees,
|
||||
[&](auto const &in_it, auto const &out_it) {
|
||||
auto new_tree =
|
||||
std::make_unique<RegTree>(*this->model_.trees.at(in_it));
|
||||
bst_group_t group = this->model_.tree_info[in_it];
|
||||
out_trees.at(out_it) = std::move(new_tree);
|
||||
out_trees_info.at(out_it) = group;
|
||||
});
|
||||
}
|
||||
|
||||
void GBTree::PredictBatch(DMatrix* p_fmat,
|
||||
PredictionCacheEntry* out_preds,
|
||||
bool,
|
||||
@ -494,6 +526,22 @@ class Dart : public GBTree {
|
||||
dparam_.UpdateAllowUnknown(cfg);
|
||||
}
|
||||
|
||||
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
|
||||
GradientBooster *out, bool* out_of_bound) const final {
|
||||
GBTree::Slice(layer_begin, layer_end, step, out, out_of_bound);
|
||||
if (*out_of_bound) {
|
||||
return;
|
||||
}
|
||||
auto p_dart = dynamic_cast<Dart*>(out);
|
||||
CHECK(p_dart);
|
||||
CHECK(p_dart->weight_drop_.empty());
|
||||
detail::SliceTrees(
|
||||
layer_begin, layer_end, step, model_, tparam_, this->LayerTrees(),
|
||||
[&](auto const& in_it, auto const&) {
|
||||
p_dart->weight_drop_.push_back(this->weight_drop_.at(in_it));
|
||||
});
|
||||
}
|
||||
|
||||
void SaveModel(Json *p_out) const override {
|
||||
auto &out = *p_out;
|
||||
out["name"] = String("dart");
|
||||
|
||||
@ -152,6 +152,50 @@ struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// From here on, layer becomes concrete trees.
|
||||
inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const &model,
|
||||
GBTreeTrainParam const &tparam,
|
||||
size_t layer_begin,
|
||||
size_t layer_end) {
|
||||
bst_group_t groups = model.learner_model_param->num_output_group;
|
||||
uint32_t tree_begin = layer_begin * groups * tparam.num_parallel_tree;
|
||||
uint32_t tree_end = layer_end * groups * tparam.num_parallel_tree;
|
||||
if (tree_end == 0) {
|
||||
tree_end = static_cast<uint32_t>(model.trees.size());
|
||||
}
|
||||
CHECK_LT(tree_begin, tree_end);
|
||||
return {tree_begin, tree_end};
|
||||
}
|
||||
|
||||
// Call fn for each pair of input output tree. Return true if index is out of bound.
|
||||
template <typename Func>
|
||||
inline bool SliceTrees(int32_t layer_begin, int32_t layer_end, int32_t step,
|
||||
GBTreeModel const &model, GBTreeTrainParam const &tparam,
|
||||
uint32_t layer_trees, Func fn) {
|
||||
uint32_t tree_begin, tree_end;
|
||||
std::tie(tree_begin, tree_end) = detail::LayerToTree(model, tparam, layer_begin, layer_end);
|
||||
if (tree_end > model.trees.size()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
layer_end = layer_end == 0 ? model.trees.size() / layer_trees : layer_end;
|
||||
uint32_t n_layers = (layer_end - layer_begin) / step;
|
||||
int32_t in_it = tree_begin;
|
||||
int32_t out_it = 0;
|
||||
for (uint32_t l = 0; l < n_layers; ++l) {
|
||||
for (uint32_t i = 0; i < layer_trees; ++i) {
|
||||
CHECK_LT(in_it, tree_end);
|
||||
fn(in_it, out_it);
|
||||
out_it++;
|
||||
in_it++;
|
||||
}
|
||||
in_it += (step - 1) * layer_trees;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// gradient boosted trees
|
||||
class GBTree : public GradientBooster {
|
||||
public:
|
||||
@ -200,6 +244,15 @@ class GBTree : public GradientBooster {
|
||||
return model_.learner_model_param->num_output_group == 1;
|
||||
}
|
||||
|
||||
// Number of trees per layer.
|
||||
auto LayerTrees() const {
|
||||
auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree;
|
||||
return n_trees;
|
||||
}
|
||||
// slice the trees, out must be already allocated
|
||||
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
|
||||
GradientBooster *out, bool* out_of_bound) const override;
|
||||
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
PredictionCacheEntry* out_preds,
|
||||
bool training,
|
||||
@ -210,13 +263,8 @@ class GBTree : public GradientBooster {
|
||||
uint32_t layer_begin,
|
||||
unsigned layer_end) const override {
|
||||
CHECK(configured_);
|
||||
// From here on, layer becomes concrete trees.
|
||||
bst_group_t groups = model_.learner_model_param->num_output_group;
|
||||
uint32_t tree_begin = layer_begin * groups * tparam_.num_parallel_tree;
|
||||
uint32_t tree_end = layer_end * groups * tparam_.num_parallel_tree;
|
||||
if (tree_end == 0 || tree_end > model_.trees.size()) {
|
||||
tree_end = static_cast<uint32_t>(model_.trees.size());
|
||||
}
|
||||
uint32_t tree_begin, tree_end;
|
||||
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||
this->GetPredictor()->InplacePredict(x, model_, missing, out_preds,
|
||||
tree_begin, tree_end);
|
||||
}
|
||||
|
||||
@ -6,10 +6,10 @@
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "gbtree_model.h"
|
||||
#include "gbtree.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
void GBTreeModel::Save(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param.num_trees, static_cast<int32_t>(trees.size()));
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2019 by Contributors
|
||||
* Copyright 2017-2020 by Contributors
|
||||
* \file gbtree_model.h
|
||||
*/
|
||||
#ifndef XGBOOST_GBM_GBTREE_MODEL_H_
|
||||
@ -22,6 +22,7 @@ namespace xgboost {
|
||||
class Json;
|
||||
|
||||
namespace gbm {
|
||||
|
||||
/*! \brief model parameters */
|
||||
struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
|
||||
public:
|
||||
|
||||
@ -971,6 +971,26 @@ class LearnerImpl : public LearnerIO {
|
||||
return gbm_->DumpModel(fmap, with_stats, format);
|
||||
}
|
||||
|
||||
Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
|
||||
bool *out_of_bound) override {
|
||||
this->Configure();
|
||||
CHECK_GE(begin_layer, 0);
|
||||
auto *out_impl = new LearnerImpl({});
|
||||
auto gbm = std::unique_ptr<GradientBooster>(GradientBooster::Create(
|
||||
this->tparam_.booster, &this->generic_parameters_,
|
||||
&this->learner_model_param_));
|
||||
this->gbm_->Slice(begin_layer, end_layer, step, gbm.get(), out_of_bound);
|
||||
out_impl->gbm_ = std::move(gbm);
|
||||
Json config { Object() };
|
||||
this->SaveConfig(&config);
|
||||
out_impl->mparam_ = this->mparam_;
|
||||
out_impl->attributes_ = this->attributes_;
|
||||
out_impl->learner_model_param_ = this->learner_model_param_;
|
||||
out_impl->LoadConfig(config);
|
||||
out_impl->Configure();
|
||||
return out_impl;
|
||||
}
|
||||
|
||||
void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
|
||||
monitor_.Start("UpdateOneIter");
|
||||
TrainingObserver::Instance().Update(iter);
|
||||
|
||||
@ -154,9 +154,9 @@ TEST(GBTree, JsonIO) {
|
||||
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
|
||||
|
||||
auto const& gbtree_model = model["model"]["model"];
|
||||
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1);
|
||||
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1ul);
|
||||
ASSERT_EQ(get<Integer>(get<Object>(get<Array>(gbtree_model["trees"]).front()).at("id")), 0);
|
||||
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1);
|
||||
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1ul);
|
||||
|
||||
auto j_train_param = model["config"]["gbtree_train_param"];
|
||||
ASSERT_EQ(get<String>(j_train_param["num_parallel_tree"]), "1");
|
||||
@ -194,7 +194,7 @@ TEST(Dart, JsonIO) {
|
||||
ASSERT_EQ(get<String>(model["model"]["name"]), "dart") << model;
|
||||
ASSERT_EQ(get<String>(model["config"]["name"]), "dart");
|
||||
ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
|
||||
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
|
||||
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0ul);
|
||||
}
|
||||
|
||||
TEST(Dart, Prediction) {
|
||||
@ -230,4 +230,122 @@ TEST(Dart, Prediction) {
|
||||
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps);
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<Json, Json> TestModelSlice(std::string booster) {
|
||||
size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3;
|
||||
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true, false, kClasses);
|
||||
|
||||
int32_t kIters = 10;
|
||||
std::unique_ptr<Learner> learner {
|
||||
Learner::Create({m})
|
||||
};
|
||||
learner->SetParams(Args{{"booster", booster},
|
||||
{"tree_method", "hist"},
|
||||
{"num_parallel_tree", std::to_string(kForest)},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"subsample", "0.5"},
|
||||
{"max_depth", "2"}});
|
||||
|
||||
for (auto i = 0; i < kIters; ++i) {
|
||||
learner->UpdateOneIter(i, m);
|
||||
}
|
||||
|
||||
Json model{Object()};
|
||||
Json config{Object()};
|
||||
learner->SaveModel(&model);
|
||||
learner->SaveConfig(&config);
|
||||
bool out_of_bound = false;
|
||||
|
||||
size_t constexpr kSliceStart = 2, kSliceEnd = 8, kStep = 3;
|
||||
std::unique_ptr<Learner> sliced {learner->Slice(kSliceStart, kSliceEnd, kStep, &out_of_bound)};
|
||||
Json sliced_model{Object()};
|
||||
sliced->SaveModel(&sliced_model);
|
||||
|
||||
auto get_shape = [&](Json const& model) {
|
||||
if (booster == "gbtree") {
|
||||
return get<Object const>(model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]);
|
||||
} else {
|
||||
return get<Object const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["gbtree_model_param"]);
|
||||
}
|
||||
};
|
||||
|
||||
auto const& model_shape = get_shape(sliced_model);
|
||||
CHECK_EQ(get<String const>(model_shape.at("num_trees")), std::to_string(2 * kClasses * kForest));
|
||||
|
||||
Json sliced_config {Object()};
|
||||
sliced->SaveConfig(&sliced_config);
|
||||
CHECK_EQ(sliced_config, config);
|
||||
|
||||
auto get_trees = [&](Json const& model) {
|
||||
if (booster == "gbtree") {
|
||||
return get<Array const>(model["learner"]["gradient_booster"]["model"]["trees"]);
|
||||
} else {
|
||||
return get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
|
||||
}
|
||||
};
|
||||
|
||||
auto get_info = [&](Json const& model) {
|
||||
if (booster == "gbtree") {
|
||||
return get<Array const>(model["learner"]["gradient_booster"]["model"]["tree_info"]);
|
||||
} else {
|
||||
return get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["tree_info"]);
|
||||
}
|
||||
};
|
||||
|
||||
auto const &sliced_trees = get_trees(sliced_model);
|
||||
CHECK_EQ(sliced_trees.size(), 2 * kClasses * kForest);
|
||||
|
||||
auto constexpr kLayerSize = kClasses * kForest;
|
||||
auto const &sliced_info = get_info(sliced_model);
|
||||
|
||||
for (size_t layer = 0; layer < 2; ++layer) {
|
||||
for (size_t j = 0; j < kClasses; ++j) {
|
||||
for (size_t k = 0; k < kForest; ++k) {
|
||||
auto idx = layer * kLayerSize + j * kForest + k;
|
||||
auto const &group = get<Integer const>(sliced_info.at(idx));
|
||||
CHECK_EQ(static_cast<size_t>(group), j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto const& trees = get_trees(model);
|
||||
|
||||
// Sliced layers are [2, 5]
|
||||
auto begin = kLayerSize * kSliceStart;
|
||||
auto end = begin + kLayerSize;
|
||||
auto j = 0;
|
||||
for (size_t i = begin; i < end; ++i) {
|
||||
Json tree = trees[i];
|
||||
tree["id"] = Integer(0); // id is different, we set it to 0 to allow comparison.
|
||||
auto sliced_tree = sliced_trees[j];
|
||||
sliced_tree["id"] = Integer(0);
|
||||
CHECK_EQ(tree, sliced_tree);
|
||||
j++;
|
||||
}
|
||||
|
||||
begin = kLayerSize * (kSliceStart + kStep);
|
||||
end = begin + kLayerSize;
|
||||
for (size_t i = begin; i < end; ++i) {
|
||||
Json tree = trees[i];
|
||||
tree["id"] = Integer(0);
|
||||
auto sliced_tree = sliced_trees[j];
|
||||
sliced_tree["id"] = Integer(0);
|
||||
CHECK_EQ(tree, sliced_tree);
|
||||
j++;
|
||||
}
|
||||
|
||||
return std::make_pair(model, sliced_model);
|
||||
}
|
||||
|
||||
TEST(GBTree, Slice) {
|
||||
TestModelSlice("gbtree");
|
||||
}
|
||||
|
||||
TEST(Dart, Slice) {
|
||||
Json model, sliced_model;
|
||||
std::tie(model, sliced_model) = TestModelSlice("dart");
|
||||
auto const& weights = get<Array const>(model["learner"]["gradient_booster"]["weight_drop"]);
|
||||
auto const& trees = get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
|
||||
ASSERT_EQ(weights.size(), trees.size());
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -118,7 +118,7 @@ TEST(Learner, Configuration) {
|
||||
|
||||
// eval_metric is not part of configuration
|
||||
auto attr_names = learner->GetConfigurationArguments();
|
||||
ASSERT_EQ(attr_names.size(), 1);
|
||||
ASSERT_EQ(attr_names.size(), 1ul);
|
||||
ASSERT_EQ(attr_names.find(emetric), attr_names.cend());
|
||||
ASSERT_EQ(attr_names.at("foo"), "bar");
|
||||
}
|
||||
@ -127,7 +127,7 @@ TEST(Learner, Configuration) {
|
||||
std::unique_ptr<Learner> learner { Learner::Create({nullptr}) };
|
||||
learner->SetParams({{"foo", "bar"}, {emetric, "auc"}, {emetric, "entropy"}, {emetric, "KL"}});
|
||||
auto attr_names = learner->GetConfigurationArguments();
|
||||
ASSERT_EQ(attr_names.size(), 1);
|
||||
ASSERT_EQ(attr_names.size(), 1ul);
|
||||
ASSERT_EQ(attr_names.at("foo"), "bar");
|
||||
}
|
||||
}
|
||||
@ -181,7 +181,7 @@ TEST(Learner, JsonModelIO) {
|
||||
learner->SaveModel(&new_in);
|
||||
|
||||
ASSERT_TRUE(IsA<Object>(out["learner"]["attributes"]));
|
||||
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1);
|
||||
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1ul);
|
||||
ASSERT_EQ(out, new_in);
|
||||
}
|
||||
}
|
||||
@ -333,5 +333,4 @@ TEST(Learner, Seed) {
|
||||
ASSERT_EQ(std::to_string(seed),
|
||||
get<String>(config["learner"]["generic_param"]["seed"]));
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -29,7 +29,7 @@ def json_model(model_path, parameters):
|
||||
return model
|
||||
|
||||
|
||||
class TestModels(unittest.TestCase):
|
||||
class TestModels:
|
||||
def test_glm(self):
|
||||
param = {'verbosity': 0, 'objective': 'binary:logistic',
|
||||
'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1,
|
||||
@ -209,12 +209,14 @@ class TestModels(unittest.TestCase):
|
||||
|
||||
bst = xgb.train([], dm1)
|
||||
bst.predict(dm1) # success
|
||||
self.assertRaises(ValueError, bst.predict, dm2)
|
||||
with pytest.raises(ValueError):
|
||||
bst.predict(dm2)
|
||||
bst.predict(dm1) # success
|
||||
|
||||
bst = xgb.train([], dm2)
|
||||
bst.predict(dm2) # success
|
||||
self.assertRaises(ValueError, bst.predict, dm1)
|
||||
with pytest.raises(ValueError):
|
||||
bst.predict(dm1)
|
||||
bst.predict(dm2) # success
|
||||
|
||||
def test_model_binary_io(self):
|
||||
@ -325,3 +327,96 @@ class TestModels(unittest.TestCase):
|
||||
parameters = {'tree_method': 'hist', 'booster': 'dart',
|
||||
'objective': 'multi:softmax'}
|
||||
validate_model(parameters)
|
||||
|
||||
@pytest.mark.parametrize('booster', ['gbtree', 'dart'])
|
||||
def test_slice(self, booster):
|
||||
from sklearn.datasets import make_classification
|
||||
num_classes = 3
|
||||
X, y = make_classification(n_samples=1000, n_informative=5,
|
||||
n_classes=num_classes)
|
||||
dtrain = xgb.DMatrix(data=X, label=y)
|
||||
num_parallel_tree = 4
|
||||
num_boost_round = 16
|
||||
total_trees = num_parallel_tree * num_classes * num_boost_round
|
||||
booster = xgb.train({
|
||||
'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster,
|
||||
'objective': 'multi:softprob'},
|
||||
num_boost_round=num_boost_round, dtrain=dtrain)
|
||||
assert len(booster.get_dump()) == total_trees
|
||||
beg = 3
|
||||
end = 7
|
||||
sliced: xgb.Booster = booster[beg: end]
|
||||
|
||||
sliced_trees = (end - beg) * num_parallel_tree * num_classes
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
sliced_trees = sliced_trees // 2
|
||||
sliced: xgb.Booster = booster[beg: end: 2]
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
sliced: xgb.Booster = booster[beg: ...]
|
||||
sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
sliced: xgb.Booster = booster[beg:]
|
||||
sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
sliced: xgb.Booster = booster[:end]
|
||||
sliced_trees = end * num_parallel_tree * num_classes
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
sliced: xgb.Booster = booster[...:end]
|
||||
sliced_trees = end * num_parallel_tree * num_classes
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
with pytest.raises(ValueError, match=r'>= 0'):
|
||||
booster[-1: 0]
|
||||
|
||||
# we do not accept empty slice.
|
||||
with pytest.raises(ValueError):
|
||||
booster[1:1]
|
||||
# stop can not be smaller than begin
|
||||
with pytest.raises(ValueError, match=r'Invalid.*'):
|
||||
booster[3:0]
|
||||
with pytest.raises(ValueError, match=r'Invalid.*'):
|
||||
booster[3:-1]
|
||||
# negative step is not supported.
|
||||
with pytest.raises(ValueError, match=r'.*>= 1.*'):
|
||||
booster[0:2:-1]
|
||||
# step can not be 0.
|
||||
with pytest.raises(ValueError, match=r'.*>= 1.*'):
|
||||
booster[0:2:0]
|
||||
|
||||
trees = [_ for _ in booster]
|
||||
assert len(trees) == num_boost_round
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
booster["wrong type"]
|
||||
with pytest.raises(IndexError):
|
||||
booster[:num_boost_round+1]
|
||||
with pytest.raises(ValueError):
|
||||
booster[1, 2] # too many dims
|
||||
# setitem is not implemented as model is immutable during slicing.
|
||||
with pytest.raises(TypeError):
|
||||
booster[...:end] = booster
|
||||
|
||||
sliced_0 = booster[1:3]
|
||||
sliced_1 = booster[3:7]
|
||||
|
||||
predt_0 = sliced_0.predict(dtrain, output_margin=True)
|
||||
predt_1 = sliced_1.predict(dtrain, output_margin=True)
|
||||
|
||||
merged = predt_0 + predt_1 - 0.5 # base score.
|
||||
single = booster[1:7].predict(dtrain, output_margin=True)
|
||||
np.testing.assert_allclose(merged, single, atol=1e-6)
|
||||
|
||||
sliced_0 = booster[1:7:2] # 1,3,5
|
||||
sliced_1 = booster[2:8:2] # 2,4,6
|
||||
|
||||
predt_0 = sliced_0.predict(dtrain, output_margin=True)
|
||||
predt_1 = sliced_1.predict(dtrain, output_margin=True)
|
||||
|
||||
merged = predt_0 + predt_1 - 0.5
|
||||
single = booster[1:7].predict(dtrain, output_margin=True)
|
||||
np.testing.assert_allclose(merged, single, atol=1e-6)
|
||||
|
||||
@ -113,6 +113,35 @@ class TestCallbacks(unittest.TestCase):
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
def test_early_stopping_save_best_model(self):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
n_estimators = 100
|
||||
cls = xgb.XGBClassifier(n_estimators=n_estimators)
|
||||
early_stopping_rounds = 5
|
||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||
save_best=True)
|
||||
cls.fit(X, y, eval_set=[(X, y)],
|
||||
eval_metric=tm.eval_error_metric, callbacks=[early_stop])
|
||||
booster = cls.get_booster()
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) == booster.best_iteration
|
||||
|
||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||
save_best=True)
|
||||
cls = xgb.XGBClassifier(booster='gblinear', n_estimators=10)
|
||||
self.assertRaises(ValueError, lambda: cls.fit(X, y, eval_set=[(X, y)],
|
||||
eval_metric=tm.eval_error_metric,
|
||||
callbacks=[early_stop]))
|
||||
|
||||
# No error
|
||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||
save_best=False)
|
||||
xgb.XGBClassifier(booster='gblinear', n_estimators=10).fit(
|
||||
X, y, eval_set=[(X, y)],
|
||||
eval_metric=tm.eval_error_metric,
|
||||
callbacks=[early_stop])
|
||||
|
||||
def run_eta_decay(self, tree_method, deprecated_callback):
|
||||
if deprecated_callback:
|
||||
scheduler = xgb.callback.reset_learning_rate
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user