Support slicing tree model (#6302)
This PR is meant to end the confusion around ``best_ntree_limit`` and to unify model slicing. With multi-class models and random forests, asking users to understand how to set ``ntree_limit`` is difficult and error-prone.

* Implement the ``save_best`` option in early stopping.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
parent 29745c6df2, commit 2cc9662005
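The intent of the change, in one picture (an editor's sketch, not part of the diff): instead of translating "the first few boosting rounds" into a raw tree count for ``ntree_limit``, a user slices the booster by boosting rounds and the class/forest layout is handled internally. ``booster`` and ``dtrain`` below refer to a trained multi-class random forest such as the one documented in the new doc/python/model.rst further down.

.. code-block:: python

  # Editor's illustration only; `booster` and `dtrain` come from a setup like
  # the one in doc/python/model.rst below.
  first_seven_rounds = booster[:7]            # boosting rounds [0, 7), every class, every parallel tree
  preds = first_seven_rounds.predict(dtrain)  # no ntree_limit bookkeeping required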
@@ -7,9 +7,9 @@ package. In XGBoost 1.3, a new callback interface is designed for Python packag
 provides the flexiblity of designing various extension for training. Also, XGBoost has a
 number of pre-defined callbacks for supporting early stopping, checkpoints etc.
 
-#######################
 Using builtin callbacks
-#######################
+-----------------------
 
 By default, training methods in XGBoost have parameters like ``early_stopping_rounds`` and
 ``verbose``/``verbose_eval``, when specified the training procedure will define the
@@ -50,9 +50,9 @@ this callback function directly into XGBoost:
     dump = booster.get_dump(dump_format='json')
     assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)
 
-##########################
 Defining your own callback
-##########################
+--------------------------
 
 XGBoost provides an callback interface class: ``xgboost.callback.TrainingCallback``, user
 defined callbacks should inherit this class and override corresponding methods. There's a
@@ -12,4 +12,5 @@ Contents
   python_intro
   python_api
   callbacks
+  model
   Python examples <https://github.com/dmlc/xgboost/tree/master/demo/guide-python>
doc/python/model.rst (new file, 38 lines)
@@ -0,0 +1,38 @@
+#####
+Model
+#####
+
+Slice tree model
+----------------
+
+When ``booster`` is set to ``gbtree`` or ``dart``, XGBoost builds a tree model, which is a
+list of trees and can be sliced into multiple sub-models.
+
+.. code-block:: python
+
+  from sklearn.datasets import make_classification
+  num_classes = 3
+  X, y = make_classification(n_samples=1000, n_informative=5,
+                             n_classes=num_classes)
+  dtrain = xgb.DMatrix(data=X, label=y)
+  num_parallel_tree = 4
+  num_boost_round = 16
+  # total number of built trees is num_parallel_tree * num_classes * num_boost_round
+
+  # We build a boosted random forest for classification here.
+  booster = xgb.train({
+      'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3},
+      num_boost_round=num_boost_round, dtrain=dtrain)
+
+  # This is the sliced model, containing [3, 7) forests
+  # step is also supported with some limitations like negative step is invalid.
+  sliced: xgb.Booster = booster[3:7]
+
+  # Access individual tree layer
+  trees = [_ for _ in booster]
+  assert len(trees) == num_boost_round
+
+
+The sliced model is a copy of selected trees, that means the model itself is immutable
+during slicing. This feature is the basis of `save_best` option in early stopping
+callback.
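The snippet in the new document assumes ``import xgboost as xgb``; a self-contained variant of the same example, with an explicit objective and a prediction on the sliced model added, might look like this (editor's sketch, not part of the committed file):

.. code-block:: python

  import xgboost as xgb
  from sklearn.datasets import make_classification

  num_classes = 3
  X, y = make_classification(n_samples=1000, n_informative=5,
                             n_classes=num_classes)
  dtrain = xgb.DMatrix(data=X, label=y)

  booster = xgb.train({'num_parallel_tree': 4, 'subsample': 0.5,
                       'num_class': num_classes,
                       'objective': 'multi:softprob'},
                      num_boost_round=16, dtrain=dtrain)

  sliced = booster[3:7]   # boosting rounds 3, 4, 5 and 6, across all classes and parallel trees
  # A sliced model is an ordinary Booster, so it predicts like any other model.
  margin = sliced.predict(dtrain, output_margin=True)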
@@ -580,6 +580,23 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
  */
 XGB_DLL int XGBoosterFree(BoosterHandle handle);
 
+/*!
+ * \brief Slice a model using boosting index. The slice m:n indicates taking all trees
+ *        that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
+ *
+ * \param handle Booster to be sliced.
+ * \param begin_layer start of the slice
+ * \param end_layer end of the slice; end_layer=0 is equivalent to
+ *                  end_layer=num_boost_round
+ * \param step step size of the slice
+ * \param out Sliced booster.
+ *
+ * \return 0 when success, -1 when failure happens, -2 when index is out of bound.
+ */
+XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
+                           int end_layer, int step,
+                           BoosterHandle *out);
+
 /*!
  * \brief set parameters
  * \param handle handle
@@ -60,6 +60,17 @@ class GradientBooster : public Model, public Configurable {
   * \param fo output stream
   */
  virtual void Save(dmlc::Stream* fo) const = 0;
+ /*!
+  * \brief Slice a model using boosting index. The slice m:n indicates taking all trees
+  *        that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
+  * \param layer_begin Begining of boosted tree layer used for prediction.
+  * \param layer_end   End of booster layer. 0 means do not limit trees.
+  * \param out         Output gradient booster
+  */
+ virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
+                    GradientBooster *out, bool* out_of_bound) const {
+   LOG(FATAL) << "Slice is not supported by current booster.";
+ }
 /*!
  * \brief whether the model allow lazy checkpoint
  * return true if model is only updated in DoBoost
@@ -195,6 +195,18 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
   * \return whether the model allow lazy checkpoint in rabit.
   */
  bool AllowLazyCheckPoint() const;
+ /*!
+  * \brief Slice the model.
+  *
+  * See InplacePredict for layer parameters.
+  *
+  * \param step step size between slice.
+  * \param out_of_bound Return true if end layer is out of bound.
+  *
+  * \return a sliced model.
+  */
+ virtual Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
+                        bool *out_of_bound) = 0;
 /*!
  * \brief dump the model in the requested format
  * \param fmap feature map that may help give interpretations of feature
@@ -10,7 +10,7 @@ from typing import Callable, List
 import numpy
 
 from . import rabit
-from .core import EarlyStopException, CallbackEnv
+from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
 from .compat import STRING_TYPES
 
 
@@ -279,9 +279,11 @@ class TrainingCallback(ABC):
 
     def before_training(self, model):
         '''Run before training starts.'''
+        return model
 
     def after_training(self, model):
         '''Run after training is finished.'''
+        return model
 
     def before_iteration(self, model, epoch, evals_log):
         '''Run before each iteration.  Return True when training should stop.'''
@@ -362,12 +364,24 @@ class CallbackContainer:
     def before_training(self, model):
         '''Function called before training.'''
         for c in self.callbacks:
-            c.before_training(model=model)
+            model = c.before_training(model=model)
+        msg = 'before_training should return the model'
+        if self.is_cv:
+            assert isinstance(model.cvfolds, list), msg
+        else:
+            assert isinstance(model, Booster), msg
+        return model
 
     def after_training(self, model):
         '''Function called after training.'''
         for c in self.callbacks:
-            c.after_training(model)
+            model = c.after_training(model=model)
+        msg = 'after_training should return the model'
+        if self.is_cv:
+            assert isinstance(model.cvfolds, list), msg
+        else:
+            assert isinstance(model, Booster), msg
+        return model
 
     def before_iteration(self, model, epoch, dtrain, evals):
         '''Function called before training iteration.'''
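The container now asserts that ``before_training`` and ``after_training`` hand the model back, so user callbacks have to return it. A minimal conforming callback could look like this (editor's sketch; the class name and its do-nothing logic are made up for illustration):

.. code-block:: python

  import xgboost as xgb

  class PassThrough(xgb.callback.TrainingCallback):
      '''Hypothetical callback showing the new return-the-model contract.'''
      def before_training(self, model):
          return model      # returning None would trip the container's assertion

      def after_training(self, model):
          return model      # same here; this is where EarlyStopping swaps in the best model

      def after_iteration(self, model, epoch, evals_log):
          return False      # False means: keep training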
@@ -461,7 +475,7 @@ class EarlyStopping(TrainingCallback):
     maximize : bool
         Whether to maximize evaluation metric.  None means auto (discouraged).
     save_best : bool
-        Placeholder, the feature is not yet supported.
+        Whether training should return the best model or the last model.
     '''
     def __init__(self,
                  rounds,
@@ -473,9 +487,6 @@ class EarlyStopping(TrainingCallback):
         self.metric_name = metric_name
         self.rounds = rounds
         self.save_best = save_best
-        # https://github.com/dmlc/xgboost/issues/5531
-        assert self.save_best is False, 'save best is not yet supported.'
-
         self.maximize = maximize
         self.stopping_history = {}
 
@@ -525,7 +536,7 @@ class EarlyStopping(TrainingCallback):
             return True
         return False
 
-    def after_iteration(self, model, epoch, evals_log):
+    def after_iteration(self, model: Booster, epoch, evals_log):
         msg = 'Must have at least 1 validation dataset for early stopping.'
         assert len(evals_log.keys()) >= 1, msg
         data_name = ''
@@ -551,6 +562,14 @@ class EarlyStopping(TrainingCallback):
             score = data_log[metric_name][-1]
         return self._update_rounds(score, data_name, metric_name, model, epoch)
 
+    def after_training(self, model: Booster):
+        try:
+            if self.save_best:
+                model = model[: int(model.attr('best_iteration'))]
+        except XGBoostError as e:
+            raise XGBoostError('`save_best` is not applicable to current booster') from e
+        return model
+
 
 class EvaluationMonitor(TrainingCallback):
     '''Print the evaluation result at each iteration.
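With the placeholder removed, ``save_best`` is usable from the native API as well. A sketch of the intended usage (editor's example; dataset and parameters are arbitrary, and the scikit-learn variant appears in the tests at the end of this diff):

.. code-block:: python

  import xgboost as xgb
  from sklearn.datasets import load_breast_cancer

  X, y = load_breast_cancer(return_X_y=True)
  dtrain = xgb.DMatrix(X, label=y)

  early_stop = xgb.callback.EarlyStopping(rounds=5, save_best=True)
  booster = xgb.train({'objective': 'binary:logistic'}, dtrain,
                      num_boost_round=100,
                      evals=[(dtrain, 'Train')],
                      callbacks=[early_stop])
  # after_training slices the booster at best_iteration, so only the best
  # trees are kept in the returned model.
  print(len(booster.get_dump()), booster.attr('best_iteration'))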
@@ -684,9 +703,11 @@ class LegacyCallbacks:
 
     def before_training(self, model):
         '''Nothing to do for legacy callbacks'''
+        return model
 
     def after_training(self, model):
         '''Nothing to do for legacy callbacks'''
+        return model
 
     def before_iteration(self, model, epoch, dtrain, evals):
         '''Called before each iteration.'''
@@ -944,8 +944,8 @@ class Booster(object):
             Parameters for boosters.
         cache : list
             List of cache items.
-        model_file : string or os.PathLike
-            Path to the model file.
+        model_file : string/os.PathLike/Booster/bytearray
+            Path to the model file if it's string or PathLike.
         """
         for d in cache:
             if not isinstance(d, DMatrix):
@@ -1021,6 +1021,43 @@ class Booster(object):
             state['handle'] = handle
         self.__dict__.update(state)
 
+    def __getitem__(self, val):
+        if isinstance(val, int):
+            val = slice(val, val+1)
+        if isinstance(val, tuple):
+            raise ValueError('Only supports slicing through 1 dimension.')
+        if not isinstance(val, slice):
+            msg = _expect((int, slice), type(val))
+            raise TypeError(msg)
+        if isinstance(val.start, type(Ellipsis)) or val.start is None:
+            start = 0
+        else:
+            start = val.start
+        if isinstance(val.stop, type(Ellipsis)) or val.stop is None:
+            stop = 0
+        else:
+            stop = val.stop
+        if stop < start:
+            raise ValueError('Invalid slice', val)
+
+        step = val.step if val.step is not None else 1
+
+        start = ctypes.c_int(start)
+        stop = ctypes.c_int(stop)
+        step = ctypes.c_int(step)
+
+        sliced_handle = ctypes.c_void_p()
+        status = _LIB.XGBoosterSlice(self.handle, start, stop, step,
+                                     ctypes.byref(sliced_handle))
+        if status == -2:
+            raise IndexError('Layer index out of range')
+        _check_call(status)
+
+        sliced = Booster()
+        _check_call(_LIB.XGBoosterFree(sliced.handle))
+        sliced.handle = sliced_handle
+        return sliced
+
     def save_config(self):
         '''Output internal parameter configuration of Booster as a JSON
         string.
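Given the index types accepted above, slicing behaves like ordinary Python indexing; a short usage summary (editor's sketch, assuming ``booster`` is a trained tree booster):

.. code-block:: python

  one_layer = booster[0]          # an int is promoted to the slice [0, 1)
  first_half = booster[:8]        # boosting layers [0, 8)
  every_other = booster[2:12:2]   # a positive step is accepted

  for layer in booster:           # iteration works via __getitem__ + IndexError
      pass

  # booster[1, 2]    -> ValueError: only one dimension is supported
  # booster[3:1]     -> ValueError: stop must not be smaller than start
  # booster[0:2:-1]  -> rejected on the C++ side, which requires step >= 1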
@@ -103,7 +103,7 @@ def _train_internal(params, dtrain,
                                  num_boost_round, feval, evals_result, callbacks,
                                  show_stdv=False, cvfolds=None)
 
-    callbacks.before_training(bst)
+    bst = callbacks.before_training(bst)
     for i in range(start_iteration, num_boost_round):
         if callbacks.before_iteration(bst, i, dtrain, evals):
             break
@@ -125,7 +125,7 @@ def _train_internal(params, dtrain,
             bst.save_rabit_checkpoint()
             version += 1
 
-    callbacks.after_training(bst)
+    bst = callbacks.after_training(bst)
 
     if evals_result is not None and is_new_callback:
         evals_result.update(callbacks.history)
@@ -495,9 +495,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
                       verbose_eval, early_stopping_rounds, maximize, 0,
                       num_boost_round, feval, None, callbacks,
                       show_stdv=show_stdv, cvfolds=cvfolds)
-    callbacks.before_training(cvfolds)
-
     booster = _PackedBooster(cvfolds)
+    callbacks.before_training(booster)
 
     for i in range(num_boost_round):
         if callbacks.before_iteration(booster, i, dtrain, None):
@@ -524,4 +523,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
             results = pd.DataFrame.from_dict(results)
         except ImportError:
             pass
+
+    callbacks.after_training(booster)
+
     return results
@@ -730,6 +730,22 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
   API_END();
 }
 
+XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
+                           int end_layer, int step,
+                           BoosterHandle *out) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  auto* learner = static_cast<Learner*>(handle);
+  bool out_of_bound = false;
+  auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
+  if (out_of_bound) {
+    return -2;
+  }
+  CHECK(p_out);
+  *out = p_out;
+  API_END();
+}
+
 inline void XGBoostDumpModelImpl(BoosterHandle handle, const FeatureMap &fmap,
                                  int with_stats, const char *format,
                                  xgboost::bst_ulong *len,
@@ -398,6 +398,38 @@ void GBTree::SaveModel(Json* p_out) const {
   model_.SaveModel(&model);
 }
 
+void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
+                   GradientBooster *out, bool* out_of_bound) const {
+  CHECK(configured_);
+  CHECK(out);
+
+  auto p_gbtree = dynamic_cast<GBTree *>(out);
+  CHECK(p_gbtree);
+  GBTreeModel &out_model = p_gbtree->model_;
+  auto layer_trees = this->LayerTrees();
+
+  layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end;
+  CHECK_GE(layer_end, layer_begin);
+  CHECK_GE(step, 1);
+  int32_t n_layers = (layer_end - layer_begin) / step;
+  std::vector<std::unique_ptr<RegTree>> &out_trees = out_model.trees;
+  out_trees.resize(layer_trees * n_layers);
+  std::vector<int32_t> &out_trees_info = out_model.tree_info;
+  out_trees_info.resize(layer_trees * n_layers);
+  out_model.param.num_trees = out_model.trees.size();
+  CHECK(this->model_.trees_to_update.empty());
+
+  *out_of_bound = detail::SliceTrees(
+      layer_begin, layer_end, step, this->model_, tparam_, layer_trees,
+      [&](auto const &in_it, auto const &out_it) {
+        auto new_tree =
+            std::make_unique<RegTree>(*this->model_.trees.at(in_it));
+        bst_group_t group = this->model_.tree_info[in_it];
+        out_trees.at(out_it) = std::move(new_tree);
+        out_trees_info.at(out_it) = group;
+      });
+}
+
 void GBTree::PredictBatch(DMatrix* p_fmat,
                           PredictionCacheEntry* out_preds,
                           bool,
@@ -494,6 +526,22 @@ class Dart : public GBTree {
     dparam_.UpdateAllowUnknown(cfg);
   }
 
+  void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
+             GradientBooster *out, bool* out_of_bound) const final {
+    GBTree::Slice(layer_begin, layer_end, step, out, out_of_bound);
+    if (*out_of_bound) {
+      return;
+    }
+    auto p_dart = dynamic_cast<Dart*>(out);
+    CHECK(p_dart);
+    CHECK(p_dart->weight_drop_.empty());
+    detail::SliceTrees(
+        layer_begin, layer_end, step, model_, tparam_, this->LayerTrees(),
+        [&](auto const& in_it, auto const&) {
+          p_dart->weight_drop_.push_back(this->weight_drop_.at(in_it));
+        });
+  }
+
   void SaveModel(Json *p_out) const override {
     auto &out = *p_out;
     out["name"] = String("dart");
@@ -152,6 +152,50 @@ struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
   }
 };
 
+namespace detail {
+// From here on, layer becomes concrete trees.
+inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const &model,
+                                                 GBTreeTrainParam const &tparam,
+                                                 size_t layer_begin,
+                                                 size_t layer_end) {
+  bst_group_t groups = model.learner_model_param->num_output_group;
+  uint32_t tree_begin = layer_begin * groups * tparam.num_parallel_tree;
+  uint32_t tree_end = layer_end * groups * tparam.num_parallel_tree;
+  if (tree_end == 0) {
+    tree_end = static_cast<uint32_t>(model.trees.size());
+  }
+  CHECK_LT(tree_begin, tree_end);
+  return {tree_begin, tree_end};
+}
+
+// Call fn for each pair of input output tree.  Return true if index is out of bound.
+template <typename Func>
+inline bool SliceTrees(int32_t layer_begin, int32_t layer_end, int32_t step,
+                       GBTreeModel const &model, GBTreeTrainParam const &tparam,
+                       uint32_t layer_trees, Func fn) {
+  uint32_t tree_begin, tree_end;
+  std::tie(tree_begin, tree_end) = detail::LayerToTree(model, tparam, layer_begin, layer_end);
+  if (tree_end > model.trees.size()) {
+    return true;
+  }
+
+  layer_end = layer_end == 0 ? model.trees.size() / layer_trees : layer_end;
+  uint32_t n_layers = (layer_end - layer_begin) / step;
+  int32_t in_it = tree_begin;
+  int32_t out_it = 0;
+  for (uint32_t l = 0; l < n_layers; ++l) {
+    for (uint32_t i = 0; i < layer_trees; ++i) {
+      CHECK_LT(in_it, tree_end);
+      fn(in_it, out_it);
+      out_it++;
+      in_it++;
+    }
+    in_it += (step - 1) * layer_trees;
+  }
+  return false;
+}
+}  // namespace detail
+
 // gradient boosted trees
 class GBTree : public GradientBooster {
  public:
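The layer-to-tree arithmetic in ``LayerToTree`` is simple enough to check by hand; a Python restatement of it (editor's sketch mirroring the function above, not an exposed API):

.. code-block:: python

  def layer_to_tree(layer_begin, layer_end, num_groups, num_parallel_tree, total_trees):
      # One boosting layer holds num_groups * num_parallel_tree trees.
      tree_begin = layer_begin * num_groups * num_parallel_tree
      tree_end = layer_end * num_groups * num_parallel_tree
      if tree_end == 0:          # end_layer == 0 means "up to the last tree"
          tree_end = total_trees
      assert tree_begin < tree_end
      return tree_begin, tree_end

  # 3 classes x 4 parallel trees x 16 rounds = 192 trees in total,
  # so layers [3, 7) map to the tree range [36, 84).
  assert layer_to_tree(3, 7, 3, 4, 192) == (36, 84)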
@@ -200,6 +244,15 @@ class GBTree : public GradientBooster {
     return model_.learner_model_param->num_output_group == 1;
   }
 
+  // Number of trees per layer.
+  auto LayerTrees() const {
+    auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree;
+    return n_trees;
+  }
+  // slice the trees, out must be already allocated
+  void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
+             GradientBooster *out, bool* out_of_bound) const override;
+
   void PredictBatch(DMatrix* p_fmat,
                     PredictionCacheEntry* out_preds,
                     bool training,
@@ -210,13 +263,8 @@ class GBTree : public GradientBooster {
                       uint32_t layer_begin,
                       unsigned layer_end) const override {
     CHECK(configured_);
-    // From here on, layer becomes concrete trees.
-    bst_group_t groups = model_.learner_model_param->num_output_group;
-    uint32_t tree_begin = layer_begin * groups * tparam_.num_parallel_tree;
-    uint32_t tree_end = layer_end * groups * tparam_.num_parallel_tree;
-    if (tree_end == 0 || tree_end > model_.trees.size()) {
-      tree_end = static_cast<uint32_t>(model_.trees.size());
-    }
+    uint32_t tree_begin, tree_end;
+    std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
     this->GetPredictor()->InplacePredict(x, model_, missing, out_preds,
                                          tree_begin, tree_end);
   }
@@ -6,10 +6,10 @@
 #include "xgboost/json.h"
 #include "xgboost/logging.h"
 #include "gbtree_model.h"
+#include "gbtree.h"
 
 namespace xgboost {
 namespace gbm {
 
 void GBTreeModel::Save(dmlc::Stream* fo) const {
   CHECK_EQ(param.num_trees, static_cast<int32_t>(trees.size()));
 
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2019 by Contributors
+ * Copyright 2017-2020 by Contributors
  * \file gbtree_model.h
  */
 #ifndef XGBOOST_GBM_GBTREE_MODEL_H_
@@ -22,6 +22,7 @@ namespace xgboost {
 class Json;
 
 namespace gbm {
+
 /*! \brief model parameters */
 struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
  public:
@@ -971,6 +971,26 @@ class LearnerImpl : public LearnerIO {
     return gbm_->DumpModel(fmap, with_stats, format);
   }
 
+  Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
+                 bool *out_of_bound) override {
+    this->Configure();
+    CHECK_GE(begin_layer, 0);
+    auto *out_impl = new LearnerImpl({});
+    auto gbm = std::unique_ptr<GradientBooster>(GradientBooster::Create(
+        this->tparam_.booster, &this->generic_parameters_,
+        &this->learner_model_param_));
+    this->gbm_->Slice(begin_layer, end_layer, step, gbm.get(), out_of_bound);
+    out_impl->gbm_ = std::move(gbm);
+    Json config { Object() };
+    this->SaveConfig(&config);
+    out_impl->mparam_ = this->mparam_;
+    out_impl->attributes_ = this->attributes_;
+    out_impl->learner_model_param_ = this->learner_model_param_;
+    out_impl->LoadConfig(config);
+    out_impl->Configure();
+    return out_impl;
+  }
+
   void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
     monitor_.Start("UpdateOneIter");
     TrainingObserver::Instance().Update(iter);
@@ -154,9 +154,9 @@ TEST(GBTree, JsonIO) {
   ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
 
   auto const& gbtree_model = model["model"]["model"];
-  ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1);
+  ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1ul);
   ASSERT_EQ(get<Integer>(get<Object>(get<Array>(gbtree_model["trees"]).front()).at("id")), 0);
-  ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1);
+  ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1ul);
 
   auto j_train_param = model["config"]["gbtree_train_param"];
   ASSERT_EQ(get<String>(j_train_param["num_parallel_tree"]), "1");
@@ -194,7 +194,7 @@ TEST(Dart, JsonIO) {
   ASSERT_EQ(get<String>(model["model"]["name"]), "dart") << model;
   ASSERT_EQ(get<String>(model["config"]["name"]), "dart");
   ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
-  ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
+  ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0ul);
 }
 
 TEST(Dart, Prediction) {
@@ -230,4 +230,122 @@ TEST(Dart, Prediction) {
     ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps);
   }
 }
+
+std::pair<Json, Json> TestModelSlice(std::string booster) {
+  size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3;
+  auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true, false, kClasses);
+
+  int32_t kIters = 10;
+  std::unique_ptr<Learner> learner {
+    Learner::Create({m})
+  };
+  learner->SetParams(Args{{"booster", booster},
+                          {"tree_method", "hist"},
+                          {"num_parallel_tree", std::to_string(kForest)},
+                          {"num_class", std::to_string(kClasses)},
+                          {"subsample", "0.5"},
+                          {"max_depth", "2"}});
+
+  for (auto i = 0; i < kIters; ++i) {
+    learner->UpdateOneIter(i, m);
+  }
+
+  Json model{Object()};
+  Json config{Object()};
+  learner->SaveModel(&model);
+  learner->SaveConfig(&config);
+  bool out_of_bound = false;
+
+  size_t constexpr kSliceStart = 2, kSliceEnd = 8, kStep = 3;
+  std::unique_ptr<Learner> sliced {learner->Slice(kSliceStart, kSliceEnd, kStep, &out_of_bound)};
+  Json sliced_model{Object()};
+  sliced->SaveModel(&sliced_model);
+
+  auto get_shape = [&](Json const& model) {
+    if (booster == "gbtree") {
+      return get<Object const>(model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]);
+    } else {
+      return get<Object const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["gbtree_model_param"]);
+    }
+  };
+
+  auto const& model_shape = get_shape(sliced_model);
+  CHECK_EQ(get<String const>(model_shape.at("num_trees")), std::to_string(2 * kClasses * kForest));
+
+  Json sliced_config {Object()};
+  sliced->SaveConfig(&sliced_config);
+  CHECK_EQ(sliced_config, config);
+
+  auto get_trees = [&](Json const& model) {
+    if (booster == "gbtree") {
+      return get<Array const>(model["learner"]["gradient_booster"]["model"]["trees"]);
+    } else {
+      return get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
+    }
+  };
+
+  auto get_info = [&](Json const& model) {
+    if (booster == "gbtree") {
+      return get<Array const>(model["learner"]["gradient_booster"]["model"]["tree_info"]);
+    } else {
+      return get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["tree_info"]);
+    }
+  };
+
+  auto const &sliced_trees = get_trees(sliced_model);
+  CHECK_EQ(sliced_trees.size(), 2 * kClasses * kForest);
+
+  auto constexpr kLayerSize = kClasses * kForest;
+  auto const &sliced_info = get_info(sliced_model);
+
+  for (size_t layer = 0; layer < 2; ++layer) {
+    for (size_t j = 0; j < kClasses; ++j) {
+      for (size_t k = 0; k < kForest; ++k) {
+        auto idx = layer * kLayerSize + j * kForest + k;
+        auto const &group = get<Integer const>(sliced_info.at(idx));
+        CHECK_EQ(static_cast<size_t>(group), j);
+      }
+    }
+  }
+
+  auto const& trees = get_trees(model);
+
+  // Sliced layers are [2, 5]
+  auto begin = kLayerSize * kSliceStart;
+  auto end = begin + kLayerSize;
+  auto j = 0;
+  for (size_t i = begin; i < end; ++i) {
+    Json tree = trees[i];
+    tree["id"] = Integer(0);  // id is different, we set it to 0 to allow comparison.
+    auto sliced_tree = sliced_trees[j];
+    sliced_tree["id"] = Integer(0);
+    CHECK_EQ(tree, sliced_tree);
+    j++;
+  }
+
+  begin = kLayerSize * (kSliceStart + kStep);
+  end = begin + kLayerSize;
+  for (size_t i = begin; i < end; ++i) {
+    Json tree = trees[i];
+    tree["id"] = Integer(0);
+    auto sliced_tree = sliced_trees[j];
+    sliced_tree["id"] = Integer(0);
+    CHECK_EQ(tree, sliced_tree);
+    j++;
+  }
+
+  return std::make_pair(model, sliced_model);
+}
+
+TEST(GBTree, Slice) {
+  TestModelSlice("gbtree");
+}
+
+TEST(Dart, Slice) {
+  Json model, sliced_model;
+  std::tie(model, sliced_model) = TestModelSlice("dart");
+  auto const& weights = get<Array const>(model["learner"]["gradient_booster"]["weight_drop"]);
+  auto const& trees = get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
+  ASSERT_EQ(weights.size(), trees.size());
+}
 }  // namespace xgboost
@@ -118,7 +118,7 @@ TEST(Learner, Configuration) {
 
   // eval_metric is not part of configuration
   auto attr_names = learner->GetConfigurationArguments();
-  ASSERT_EQ(attr_names.size(), 1);
+  ASSERT_EQ(attr_names.size(), 1ul);
   ASSERT_EQ(attr_names.find(emetric), attr_names.cend());
   ASSERT_EQ(attr_names.at("foo"), "bar");
 }
@@ -127,7 +127,7 @@ TEST(Learner, Configuration) {
   std::unique_ptr<Learner> learner { Learner::Create({nullptr}) };
   learner->SetParams({{"foo", "bar"}, {emetric, "auc"}, {emetric, "entropy"}, {emetric, "KL"}});
   auto attr_names = learner->GetConfigurationArguments();
-  ASSERT_EQ(attr_names.size(), 1);
+  ASSERT_EQ(attr_names.size(), 1ul);
   ASSERT_EQ(attr_names.at("foo"), "bar");
 }
 }
@@ -181,7 +181,7 @@ TEST(Learner, JsonModelIO) {
   learner->SaveModel(&new_in);
 
   ASSERT_TRUE(IsA<Object>(out["learner"]["attributes"]));
-  ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1);
+  ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1ul);
   ASSERT_EQ(out, new_in);
 }
 }
@@ -333,5 +333,4 @@ TEST(Learner, Seed) {
   ASSERT_EQ(std::to_string(seed),
             get<String>(config["learner"]["generic_param"]["seed"]));
 }
-
 }  // namespace xgboost
@@ -29,7 +29,7 @@ def json_model(model_path, parameters):
     return model
 
 
-class TestModels(unittest.TestCase):
+class TestModels:
     def test_glm(self):
         param = {'verbosity': 0, 'objective': 'binary:logistic',
                  'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1,
@@ -209,12 +209,14 @@ class TestModels(unittest.TestCase):
 
         bst = xgb.train([], dm1)
         bst.predict(dm1)  # success
-        self.assertRaises(ValueError, bst.predict, dm2)
+        with pytest.raises(ValueError):
+            bst.predict(dm2)
         bst.predict(dm1)  # success
 
         bst = xgb.train([], dm2)
         bst.predict(dm2)  # success
-        self.assertRaises(ValueError, bst.predict, dm1)
+        with pytest.raises(ValueError):
+            bst.predict(dm1)
         bst.predict(dm2)  # success
 
     def test_model_binary_io(self):
@@ -325,3 +327,96 @@ class TestModels(unittest.TestCase):
         parameters = {'tree_method': 'hist', 'booster': 'dart',
                       'objective': 'multi:softmax'}
         validate_model(parameters)
+
+    @pytest.mark.parametrize('booster', ['gbtree', 'dart'])
+    def test_slice(self, booster):
+        from sklearn.datasets import make_classification
+        num_classes = 3
+        X, y = make_classification(n_samples=1000, n_informative=5,
+                                   n_classes=num_classes)
+        dtrain = xgb.DMatrix(data=X, label=y)
+        num_parallel_tree = 4
+        num_boost_round = 16
+        total_trees = num_parallel_tree * num_classes * num_boost_round
+        booster = xgb.train({
+            'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster,
+            'objective': 'multi:softprob'},
+            num_boost_round=num_boost_round, dtrain=dtrain)
+        assert len(booster.get_dump()) == total_trees
+        beg = 3
+        end = 7
+        sliced: xgb.Booster = booster[beg: end]
+
+        sliced_trees = (end - beg) * num_parallel_tree * num_classes
+        assert sliced_trees == len(sliced.get_dump())
+
+        sliced_trees = sliced_trees // 2
+        sliced: xgb.Booster = booster[beg: end: 2]
+        assert sliced_trees == len(sliced.get_dump())
+
+        sliced: xgb.Booster = booster[beg: ...]
+        sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
+        assert sliced_trees == len(sliced.get_dump())
+
+        sliced: xgb.Booster = booster[beg:]
+        sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
+        assert sliced_trees == len(sliced.get_dump())
+
+        sliced: xgb.Booster = booster[:end]
+        sliced_trees = end * num_parallel_tree * num_classes
+        assert sliced_trees == len(sliced.get_dump())
+
+        sliced: xgb.Booster = booster[...:end]
+        sliced_trees = end * num_parallel_tree * num_classes
+        assert sliced_trees == len(sliced.get_dump())
+
+        with pytest.raises(ValueError, match=r'>= 0'):
+            booster[-1: 0]
+
+        # we do not accept empty slice.
+        with pytest.raises(ValueError):
+            booster[1:1]
+        # stop can not be smaller than begin
+        with pytest.raises(ValueError, match=r'Invalid.*'):
+            booster[3:0]
+        with pytest.raises(ValueError, match=r'Invalid.*'):
+            booster[3:-1]
+        # negative step is not supported.
+        with pytest.raises(ValueError, match=r'.*>= 1.*'):
+            booster[0:2:-1]
+        # step can not be 0.
+        with pytest.raises(ValueError, match=r'.*>= 1.*'):
+            booster[0:2:0]
+
+        trees = [_ for _ in booster]
+        assert len(trees) == num_boost_round
+
+        with pytest.raises(TypeError):
+            booster["wrong type"]
+        with pytest.raises(IndexError):
+            booster[:num_boost_round+1]
+        with pytest.raises(ValueError):
+            booster[1, 2]  # too many dims
+        # setitem is not implemented as model is immutable during slicing.
+        with pytest.raises(TypeError):
+            booster[...:end] = booster
+
+        sliced_0 = booster[1:3]
+        sliced_1 = booster[3:7]
+
+        predt_0 = sliced_0.predict(dtrain, output_margin=True)
+        predt_1 = sliced_1.predict(dtrain, output_margin=True)
+
+        merged = predt_0 + predt_1 - 0.5  # base score.
+        single = booster[1:7].predict(dtrain, output_margin=True)
+        np.testing.assert_allclose(merged, single, atol=1e-6)
+
+        sliced_0 = booster[1:7:2]  # 1,3,5
+        sliced_1 = booster[2:8:2]  # 2,4,6
+
+        predt_0 = sliced_0.predict(dtrain, output_margin=True)
+        predt_1 = sliced_1.predict(dtrain, output_margin=True)
+
+        merged = predt_0 + predt_1 - 0.5
+        single = booster[1:7].predict(dtrain, output_margin=True)
+        np.testing.assert_allclose(merged, single, atol=1e-6)
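The last checks in ``test_slice`` rest on a property worth spelling out: gradient boosting is additive in the margin, so the raw predictions of disjoint layer ranges sum to the raw prediction of their union, except that each ``predict`` call adds the base score (0.5 by default) once. In the test's own terms (editor's note, reusing ``booster``, ``dtrain`` and ``np`` from above):

.. code-block:: python

  part_a = booster[1:4].predict(dtrain, output_margin=True)
  part_b = booster[4:7].predict(dtrain, output_margin=True)
  whole = booster[1:7].predict(dtrain, output_margin=True)
  # subtract one copy of the base score, which was counted twice
  np.testing.assert_allclose(part_a + part_b - 0.5, whole, atol=1e-6)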
@@ -113,6 +113,35 @@ class TestCallbacks(unittest.TestCase):
         dump = booster.get_dump(dump_format='json')
         assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
 
+    def test_early_stopping_save_best_model(self):
+        from sklearn.datasets import load_breast_cancer
+        X, y = load_breast_cancer(return_X_y=True)
+        n_estimators = 100
+        cls = xgb.XGBClassifier(n_estimators=n_estimators)
+        early_stopping_rounds = 5
+        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
+                                                save_best=True)
+        cls.fit(X, y, eval_set=[(X, y)],
+                eval_metric=tm.eval_error_metric, callbacks=[early_stop])
+        booster = cls.get_booster()
+        dump = booster.get_dump(dump_format='json')
+        assert len(dump) == booster.best_iteration
+
+        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
+                                                save_best=True)
+        cls = xgb.XGBClassifier(booster='gblinear', n_estimators=10)
+        self.assertRaises(ValueError, lambda: cls.fit(X, y, eval_set=[(X, y)],
+                                                      eval_metric=tm.eval_error_metric,
+                                                      callbacks=[early_stop]))
+
+        # No error
+        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
+                                                save_best=False)
+        xgb.XGBClassifier(booster='gblinear', n_estimators=10).fit(
+            X, y, eval_set=[(X, y)],
+            eval_metric=tm.eval_error_metric,
+            callbacks=[early_stop])
+
     def run_eta_decay(self, tree_method, deprecated_callback):
         if deprecated_callback:
             scheduler = xgb.callback.reset_learning_rate