Support slicing tree model (#6302)

This PR is meant the end the confusion around best_ntree_limit and unify model slicing. We have multi-class and random forests, asking users to understand how to set ntree_limit is difficult and error prone.

* Implement the save_best option in early stopping.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2020-11-03 02:27:39 -05:00 committed by GitHub
parent 29745c6df2
commit 2cc9662005
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 550 additions and 37 deletions

View File

@ -7,9 +7,9 @@ package. In XGBoost 1.3, a new callback interface is designed for Python packag
provides the flexiblity of designing various extension for training. Also, XGBoost has a
number of pre-defined callbacks for supporting early stopping, checkpoints etc.
#######################
Using builtin callbacks
#######################
-----------------------
By default, training methods in XGBoost have parameters like ``early_stopping_rounds`` and
``verbose``/``verbose_eval``, when specified the training procedure will define the
@ -50,9 +50,9 @@ this callback function directly into XGBoost:
dump = booster.get_dump(dump_format='json')
assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)
##########################
Defining your own callback
##########################
--------------------------
XGBoost provides an callback interface class: ``xgboost.callback.TrainingCallback``, user
defined callbacks should inherit this class and override corresponding methods. There's a

View File

@ -12,4 +12,5 @@ Contents
python_intro
python_api
callbacks
model
Python examples <https://github.com/dmlc/xgboost/tree/master/demo/guide-python>

38
doc/python/model.rst Normal file
View File

@ -0,0 +1,38 @@
#####
Model
#####
Slice tree model
----------------
When ``booster`` is set to ``gbtree`` or ``dart``, XGBoost builds a tree model, which is a
list of trees and can be sliced into multiple sub-models.
.. code-block:: python
from sklearn.datasets import make_classification
num_classes = 3
X, y = make_classification(n_samples=1000, n_informative=5,
n_classes=num_classes)
dtrain = xgb.DMatrix(data=X, label=y)
num_parallel_tree = 4
num_boost_round = 16
# total number of built trees is num_parallel_tree * num_classes * num_boost_round
# We build a boosted random forest for classification here.
booster = xgb.train({
'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3},
num_boost_round=num_boost_round, dtrain=dtrain)
# This is the sliced model, containing [3, 7) forests
# step is also supported with some limitations like negative step is invalid.
sliced: xgb.Booster = booster[3:7]
# Access individual tree layer
trees = [_ for _ in booster]
assert len(trees) == num_boost_round
The sliced model is a copy of selected trees, that means the model itself is immutable
during slicing. This feature is the basis of `save_best` option in early stopping
callback.

View File

@ -580,6 +580,23 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
*/
XGB_DLL int XGBoosterFree(BoosterHandle handle);
/*!
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
*
* \param handle Booster to be sliced.
* \param begin_layer start of the slice
* \param end_layer end of the slice; end_layer=0 is equivalent to
* end_layer=num_boost_round
* \param step step size of the slice
* \param out Sliced booster.
*
* \return 0 when success, -1 when failure happens, -2 when index is out of bound.
*/
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
BoosterHandle *out);
/*!
* \brief set parameters
* \param handle handle

View File

@ -60,6 +60,17 @@ class GradientBooster : public Model, public Configurable {
* \param fo output stream
*/
virtual void Save(dmlc::Stream* fo) const = 0;
/*!
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
* \param layer_begin Begining of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \param out Output gradient booster
*/
virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const {
LOG(FATAL) << "Slice is not supported by current booster.";
}
/*!
* \brief whether the model allow lazy checkpoint
* return true if model is only updated in DoBoost

View File

@ -195,6 +195,18 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \return whether the model allow lazy checkpoint in rabit.
*/
bool AllowLazyCheckPoint() const;
/*!
* \brief Slice the model.
*
* See InplacePredict for layer parameters.
*
* \param step step size between slice.
* \param out_of_bound Return true if end layer is out of bound.
*
* \return a sliced model.
*/
virtual Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
bool *out_of_bound) = 0;
/*!
* \brief dump the model in the requested format
* \param fmap feature map that may help give interpretations of feature

View File

@ -10,7 +10,7 @@ from typing import Callable, List
import numpy
from . import rabit
from .core import EarlyStopException, CallbackEnv
from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
from .compat import STRING_TYPES
@ -279,9 +279,11 @@ class TrainingCallback(ABC):
def before_training(self, model):
'''Run before training starts.'''
return model
def after_training(self, model):
'''Run after training is finished.'''
return model
def before_iteration(self, model, epoch, evals_log):
'''Run before each iteration. Return True when training should stop.'''
@ -362,12 +364,24 @@ class CallbackContainer:
def before_training(self, model):
'''Function called before training.'''
for c in self.callbacks:
c.before_training(model=model)
model = c.before_training(model=model)
msg = 'before_training should return the model'
if self.is_cv:
assert isinstance(model.cvfolds, list), msg
else:
assert isinstance(model, Booster), msg
return model
def after_training(self, model):
'''Function called after training.'''
for c in self.callbacks:
c.after_training(model)
model = c.after_training(model=model)
msg = 'after_training should return the model'
if self.is_cv:
assert isinstance(model.cvfolds, list), msg
else:
assert isinstance(model, Booster), msg
return model
def before_iteration(self, model, epoch, dtrain, evals):
'''Function called before training iteration.'''
@ -461,7 +475,7 @@ class EarlyStopping(TrainingCallback):
maximize : bool
Whether to maximize evaluation metric. None means auto (discouraged).
save_best : bool
Placeholder, the feature is not yet supported.
Whether training should return the best model or the last model.
'''
def __init__(self,
rounds,
@ -473,9 +487,6 @@ class EarlyStopping(TrainingCallback):
self.metric_name = metric_name
self.rounds = rounds
self.save_best = save_best
# https://github.com/dmlc/xgboost/issues/5531
assert self.save_best is False, 'save best is not yet supported.'
self.maximize = maximize
self.stopping_history = {}
@ -525,7 +536,7 @@ class EarlyStopping(TrainingCallback):
return True
return False
def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model: Booster, epoch, evals_log):
msg = 'Must have at least 1 validation dataset for early stopping.'
assert len(evals_log.keys()) >= 1, msg
data_name = ''
@ -551,6 +562,14 @@ class EarlyStopping(TrainingCallback):
score = data_log[metric_name][-1]
return self._update_rounds(score, data_name, metric_name, model, epoch)
def after_training(self, model: Booster):
try:
if self.save_best:
model = model[: int(model.attr('best_iteration'))]
except XGBoostError as e:
raise XGBoostError('`save_best` is not applicable to current booster') from e
return model
class EvaluationMonitor(TrainingCallback):
'''Print the evaluation result at each iteration.
@ -684,9 +703,11 @@ class LegacyCallbacks:
def before_training(self, model):
'''Nothing to do for legacy callbacks'''
return model
def after_training(self, model):
'''Nothing to do for legacy callbacks'''
return model
def before_iteration(self, model, epoch, dtrain, evals):
'''Called before each iteration.'''

View File

@ -944,8 +944,8 @@ class Booster(object):
Parameters for boosters.
cache : list
List of cache items.
model_file : string or os.PathLike
Path to the model file.
model_file : string/os.PathLike/Booster/bytearray
Path to the model file if it's string or PathLike.
"""
for d in cache:
if not isinstance(d, DMatrix):
@ -1021,6 +1021,43 @@ class Booster(object):
state['handle'] = handle
self.__dict__.update(state)
def __getitem__(self, val):
if isinstance(val, int):
val = slice(val, val+1)
if isinstance(val, tuple):
raise ValueError('Only supports slicing through 1 dimension.')
if not isinstance(val, slice):
msg = _expect((int, slice), type(val))
raise TypeError(msg)
if isinstance(val.start, type(Ellipsis)) or val.start is None:
start = 0
else:
start = val.start
if isinstance(val.stop, type(Ellipsis)) or val.stop is None:
stop = 0
else:
stop = val.stop
if stop < start:
raise ValueError('Invalid slice', val)
step = val.step if val.step is not None else 1
start = ctypes.c_int(start)
stop = ctypes.c_int(stop)
step = ctypes.c_int(step)
sliced_handle = ctypes.c_void_p()
status = _LIB.XGBoosterSlice(self.handle, start, stop, step,
ctypes.byref(sliced_handle))
if status == -2:
raise IndexError('Layer index out of range')
_check_call(status)
sliced = Booster()
_check_call(_LIB.XGBoosterFree(sliced.handle))
sliced.handle = sliced_handle
return sliced
def save_config(self):
'''Output internal parameter configuration of Booster as a JSON
string.

View File

@ -103,7 +103,7 @@ def _train_internal(params, dtrain,
num_boost_round, feval, evals_result, callbacks,
show_stdv=False, cvfolds=None)
callbacks.before_training(bst)
bst = callbacks.before_training(bst)
for i in range(start_iteration, num_boost_round):
if callbacks.before_iteration(bst, i, dtrain, evals):
break
@ -125,7 +125,7 @@ def _train_internal(params, dtrain,
bst.save_rabit_checkpoint()
version += 1
callbacks.after_training(bst)
bst = callbacks.after_training(bst)
if evals_result is not None and is_new_callback:
evals_result.update(callbacks.history)
@ -495,9 +495,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
verbose_eval, early_stopping_rounds, maximize, 0,
num_boost_round, feval, None, callbacks,
show_stdv=show_stdv, cvfolds=cvfolds)
callbacks.before_training(cvfolds)
booster = _PackedBooster(cvfolds)
callbacks.before_training(booster)
for i in range(num_boost_round):
if callbacks.before_iteration(booster, i, dtrain, None):
@ -524,4 +523,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
results = pd.DataFrame.from_dict(results)
except ImportError:
pass
callbacks.after_training(booster)
return results

View File

@ -730,6 +730,22 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_END();
}
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
BoosterHandle *out) {
API_BEGIN();
CHECK_HANDLE();
auto* learner = static_cast<Learner*>(handle);
bool out_of_bound = false;
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
if (out_of_bound) {
return -2;
}
CHECK(p_out);
*out = p_out;
API_END();
}
inline void XGBoostDumpModelImpl(BoosterHandle handle, const FeatureMap &fmap,
int with_stats, const char *format,
xgboost::bst_ulong *len,

View File

@ -398,6 +398,38 @@ void GBTree::SaveModel(Json* p_out) const {
model_.SaveModel(&model);
}
void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const {
CHECK(configured_);
CHECK(out);
auto p_gbtree = dynamic_cast<GBTree *>(out);
CHECK(p_gbtree);
GBTreeModel &out_model = p_gbtree->model_;
auto layer_trees = this->LayerTrees();
layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end;
CHECK_GE(layer_end, layer_begin);
CHECK_GE(step, 1);
int32_t n_layers = (layer_end - layer_begin) / step;
std::vector<std::unique_ptr<RegTree>> &out_trees = out_model.trees;
out_trees.resize(layer_trees * n_layers);
std::vector<int32_t> &out_trees_info = out_model.tree_info;
out_trees_info.resize(layer_trees * n_layers);
out_model.param.num_trees = out_model.trees.size();
CHECK(this->model_.trees_to_update.empty());
*out_of_bound = detail::SliceTrees(
layer_begin, layer_end, step, this->model_, tparam_, layer_trees,
[&](auto const &in_it, auto const &out_it) {
auto new_tree =
std::make_unique<RegTree>(*this->model_.trees.at(in_it));
bst_group_t group = this->model_.tree_info[in_it];
out_trees.at(out_it) = std::move(new_tree);
out_trees_info.at(out_it) = group;
});
}
void GBTree::PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* out_preds,
bool,
@ -494,6 +526,22 @@ class Dart : public GBTree {
dparam_.UpdateAllowUnknown(cfg);
}
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const final {
GBTree::Slice(layer_begin, layer_end, step, out, out_of_bound);
if (*out_of_bound) {
return;
}
auto p_dart = dynamic_cast<Dart*>(out);
CHECK(p_dart);
CHECK(p_dart->weight_drop_.empty());
detail::SliceTrees(
layer_begin, layer_end, step, model_, tparam_, this->LayerTrees(),
[&](auto const& in_it, auto const&) {
p_dart->weight_drop_.push_back(this->weight_drop_.at(in_it));
});
}
void SaveModel(Json *p_out) const override {
auto &out = *p_out;
out["name"] = String("dart");

View File

@ -152,6 +152,50 @@ struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
}
};
namespace detail {
// From here on, layer becomes concrete trees.
inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const &model,
GBTreeTrainParam const &tparam,
size_t layer_begin,
size_t layer_end) {
bst_group_t groups = model.learner_model_param->num_output_group;
uint32_t tree_begin = layer_begin * groups * tparam.num_parallel_tree;
uint32_t tree_end = layer_end * groups * tparam.num_parallel_tree;
if (tree_end == 0) {
tree_end = static_cast<uint32_t>(model.trees.size());
}
CHECK_LT(tree_begin, tree_end);
return {tree_begin, tree_end};
}
// Call fn for each pair of input output tree. Return true if index is out of bound.
template <typename Func>
inline bool SliceTrees(int32_t layer_begin, int32_t layer_end, int32_t step,
GBTreeModel const &model, GBTreeTrainParam const &tparam,
uint32_t layer_trees, Func fn) {
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model, tparam, layer_begin, layer_end);
if (tree_end > model.trees.size()) {
return true;
}
layer_end = layer_end == 0 ? model.trees.size() / layer_trees : layer_end;
uint32_t n_layers = (layer_end - layer_begin) / step;
int32_t in_it = tree_begin;
int32_t out_it = 0;
for (uint32_t l = 0; l < n_layers; ++l) {
for (uint32_t i = 0; i < layer_trees; ++i) {
CHECK_LT(in_it, tree_end);
fn(in_it, out_it);
out_it++;
in_it++;
}
in_it += (step - 1) * layer_trees;
}
return false;
}
} // namespace detail
// gradient boosted trees
class GBTree : public GradientBooster {
public:
@ -200,6 +244,15 @@ class GBTree : public GradientBooster {
return model_.learner_model_param->num_output_group == 1;
}
// Number of trees per layer.
auto LayerTrees() const {
auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree;
return n_trees;
}
// slice the trees, out must be already allocated
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const override;
void PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* out_preds,
bool training,
@ -210,13 +263,8 @@ class GBTree : public GradientBooster {
uint32_t layer_begin,
unsigned layer_end) const override {
CHECK(configured_);
// From here on, layer becomes concrete trees.
bst_group_t groups = model_.learner_model_param->num_output_group;
uint32_t tree_begin = layer_begin * groups * tparam_.num_parallel_tree;
uint32_t tree_end = layer_end * groups * tparam_.num_parallel_tree;
if (tree_end == 0 || tree_end > model_.trees.size()) {
tree_end = static_cast<uint32_t>(model_.trees.size());
}
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
this->GetPredictor()->InplacePredict(x, model_, missing, out_preds,
tree_begin, tree_end);
}

View File

@ -6,10 +6,10 @@
#include "xgboost/json.h"
#include "xgboost/logging.h"
#include "gbtree_model.h"
#include "gbtree.h"
namespace xgboost {
namespace gbm {
void GBTreeModel::Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_trees, static_cast<int32_t>(trees.size()));

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2017-2019 by Contributors
* Copyright 2017-2020 by Contributors
* \file gbtree_model.h
*/
#ifndef XGBOOST_GBM_GBTREE_MODEL_H_
@ -22,6 +22,7 @@ namespace xgboost {
class Json;
namespace gbm {
/*! \brief model parameters */
struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
public:

View File

@ -971,6 +971,26 @@ class LearnerImpl : public LearnerIO {
return gbm_->DumpModel(fmap, with_stats, format);
}
Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
bool *out_of_bound) override {
this->Configure();
CHECK_GE(begin_layer, 0);
auto *out_impl = new LearnerImpl({});
auto gbm = std::unique_ptr<GradientBooster>(GradientBooster::Create(
this->tparam_.booster, &this->generic_parameters_,
&this->learner_model_param_));
this->gbm_->Slice(begin_layer, end_layer, step, gbm.get(), out_of_bound);
out_impl->gbm_ = std::move(gbm);
Json config { Object() };
this->SaveConfig(&config);
out_impl->mparam_ = this->mparam_;
out_impl->attributes_ = this->attributes_;
out_impl->learner_model_param_ = this->learner_model_param_;
out_impl->LoadConfig(config);
out_impl->Configure();
return out_impl;
}
void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
monitor_.Start("UpdateOneIter");
TrainingObserver::Instance().Update(iter);

View File

@ -154,9 +154,9 @@ TEST(GBTree, JsonIO) {
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
auto const& gbtree_model = model["model"]["model"];
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1);
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1ul);
ASSERT_EQ(get<Integer>(get<Object>(get<Array>(gbtree_model["trees"]).front()).at("id")), 0);
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1);
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1ul);
auto j_train_param = model["config"]["gbtree_train_param"];
ASSERT_EQ(get<String>(j_train_param["num_parallel_tree"]), "1");
@ -194,7 +194,7 @@ TEST(Dart, JsonIO) {
ASSERT_EQ(get<String>(model["model"]["name"]), "dart") << model;
ASSERT_EQ(get<String>(model["config"]["name"]), "dart");
ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0ul);
}
TEST(Dart, Prediction) {
@ -230,4 +230,122 @@ TEST(Dart, Prediction) {
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps);
}
}
std::pair<Json, Json> TestModelSlice(std::string booster) {
size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3;
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true, false, kClasses);
int32_t kIters = 10;
std::unique_ptr<Learner> learner {
Learner::Create({m})
};
learner->SetParams(Args{{"booster", booster},
{"tree_method", "hist"},
{"num_parallel_tree", std::to_string(kForest)},
{"num_class", std::to_string(kClasses)},
{"subsample", "0.5"},
{"max_depth", "2"}});
for (auto i = 0; i < kIters; ++i) {
learner->UpdateOneIter(i, m);
}
Json model{Object()};
Json config{Object()};
learner->SaveModel(&model);
learner->SaveConfig(&config);
bool out_of_bound = false;
size_t constexpr kSliceStart = 2, kSliceEnd = 8, kStep = 3;
std::unique_ptr<Learner> sliced {learner->Slice(kSliceStart, kSliceEnd, kStep, &out_of_bound)};
Json sliced_model{Object()};
sliced->SaveModel(&sliced_model);
auto get_shape = [&](Json const& model) {
if (booster == "gbtree") {
return get<Object const>(model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]);
} else {
return get<Object const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["gbtree_model_param"]);
}
};
auto const& model_shape = get_shape(sliced_model);
CHECK_EQ(get<String const>(model_shape.at("num_trees")), std::to_string(2 * kClasses * kForest));
Json sliced_config {Object()};
sliced->SaveConfig(&sliced_config);
CHECK_EQ(sliced_config, config);
auto get_trees = [&](Json const& model) {
if (booster == "gbtree") {
return get<Array const>(model["learner"]["gradient_booster"]["model"]["trees"]);
} else {
return get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
}
};
auto get_info = [&](Json const& model) {
if (booster == "gbtree") {
return get<Array const>(model["learner"]["gradient_booster"]["model"]["tree_info"]);
} else {
return get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["tree_info"]);
}
};
auto const &sliced_trees = get_trees(sliced_model);
CHECK_EQ(sliced_trees.size(), 2 * kClasses * kForest);
auto constexpr kLayerSize = kClasses * kForest;
auto const &sliced_info = get_info(sliced_model);
for (size_t layer = 0; layer < 2; ++layer) {
for (size_t j = 0; j < kClasses; ++j) {
for (size_t k = 0; k < kForest; ++k) {
auto idx = layer * kLayerSize + j * kForest + k;
auto const &group = get<Integer const>(sliced_info.at(idx));
CHECK_EQ(static_cast<size_t>(group), j);
}
}
}
auto const& trees = get_trees(model);
// Sliced layers are [2, 5]
auto begin = kLayerSize * kSliceStart;
auto end = begin + kLayerSize;
auto j = 0;
for (size_t i = begin; i < end; ++i) {
Json tree = trees[i];
tree["id"] = Integer(0); // id is different, we set it to 0 to allow comparison.
auto sliced_tree = sliced_trees[j];
sliced_tree["id"] = Integer(0);
CHECK_EQ(tree, sliced_tree);
j++;
}
begin = kLayerSize * (kSliceStart + kStep);
end = begin + kLayerSize;
for (size_t i = begin; i < end; ++i) {
Json tree = trees[i];
tree["id"] = Integer(0);
auto sliced_tree = sliced_trees[j];
sliced_tree["id"] = Integer(0);
CHECK_EQ(tree, sliced_tree);
j++;
}
return std::make_pair(model, sliced_model);
}
TEST(GBTree, Slice) {
TestModelSlice("gbtree");
}
TEST(Dart, Slice) {
Json model, sliced_model;
std::tie(model, sliced_model) = TestModelSlice("dart");
auto const& weights = get<Array const>(model["learner"]["gradient_booster"]["weight_drop"]);
auto const& trees = get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
ASSERT_EQ(weights.size(), trees.size());
}
} // namespace xgboost

View File

@ -118,7 +118,7 @@ TEST(Learner, Configuration) {
// eval_metric is not part of configuration
auto attr_names = learner->GetConfigurationArguments();
ASSERT_EQ(attr_names.size(), 1);
ASSERT_EQ(attr_names.size(), 1ul);
ASSERT_EQ(attr_names.find(emetric), attr_names.cend());
ASSERT_EQ(attr_names.at("foo"), "bar");
}
@ -127,7 +127,7 @@ TEST(Learner, Configuration) {
std::unique_ptr<Learner> learner { Learner::Create({nullptr}) };
learner->SetParams({{"foo", "bar"}, {emetric, "auc"}, {emetric, "entropy"}, {emetric, "KL"}});
auto attr_names = learner->GetConfigurationArguments();
ASSERT_EQ(attr_names.size(), 1);
ASSERT_EQ(attr_names.size(), 1ul);
ASSERT_EQ(attr_names.at("foo"), "bar");
}
}
@ -181,7 +181,7 @@ TEST(Learner, JsonModelIO) {
learner->SaveModel(&new_in);
ASSERT_TRUE(IsA<Object>(out["learner"]["attributes"]));
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1);
ASSERT_EQ(get<Object>(out["learner"]["attributes"]).size(), 1ul);
ASSERT_EQ(out, new_in);
}
}
@ -333,5 +333,4 @@ TEST(Learner, Seed) {
ASSERT_EQ(std::to_string(seed),
get<String>(config["learner"]["generic_param"]["seed"]));
}
} // namespace xgboost

View File

@ -29,7 +29,7 @@ def json_model(model_path, parameters):
return model
class TestModels(unittest.TestCase):
class TestModels:
def test_glm(self):
param = {'verbosity': 0, 'objective': 'binary:logistic',
'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1,
@ -209,12 +209,14 @@ class TestModels(unittest.TestCase):
bst = xgb.train([], dm1)
bst.predict(dm1) # success
self.assertRaises(ValueError, bst.predict, dm2)
with pytest.raises(ValueError):
bst.predict(dm2)
bst.predict(dm1) # success
bst = xgb.train([], dm2)
bst.predict(dm2) # success
self.assertRaises(ValueError, bst.predict, dm1)
with pytest.raises(ValueError):
bst.predict(dm1)
bst.predict(dm2) # success
def test_model_binary_io(self):
@ -325,3 +327,96 @@ class TestModels(unittest.TestCase):
parameters = {'tree_method': 'hist', 'booster': 'dart',
'objective': 'multi:softmax'}
validate_model(parameters)
@pytest.mark.parametrize('booster', ['gbtree', 'dart'])
def test_slice(self, booster):
from sklearn.datasets import make_classification
num_classes = 3
X, y = make_classification(n_samples=1000, n_informative=5,
n_classes=num_classes)
dtrain = xgb.DMatrix(data=X, label=y)
num_parallel_tree = 4
num_boost_round = 16
total_trees = num_parallel_tree * num_classes * num_boost_round
booster = xgb.train({
'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster,
'objective': 'multi:softprob'},
num_boost_round=num_boost_round, dtrain=dtrain)
assert len(booster.get_dump()) == total_trees
beg = 3
end = 7
sliced: xgb.Booster = booster[beg: end]
sliced_trees = (end - beg) * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())
sliced_trees = sliced_trees // 2
sliced: xgb.Booster = booster[beg: end: 2]
assert sliced_trees == len(sliced.get_dump())
sliced: xgb.Booster = booster[beg: ...]
sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())
sliced: xgb.Booster = booster[beg:]
sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())
sliced: xgb.Booster = booster[:end]
sliced_trees = end * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())
sliced: xgb.Booster = booster[...:end]
sliced_trees = end * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())
with pytest.raises(ValueError, match=r'>= 0'):
booster[-1: 0]
# we do not accept empty slice.
with pytest.raises(ValueError):
booster[1:1]
# stop can not be smaller than begin
with pytest.raises(ValueError, match=r'Invalid.*'):
booster[3:0]
with pytest.raises(ValueError, match=r'Invalid.*'):
booster[3:-1]
# negative step is not supported.
with pytest.raises(ValueError, match=r'.*>= 1.*'):
booster[0:2:-1]
# step can not be 0.
with pytest.raises(ValueError, match=r'.*>= 1.*'):
booster[0:2:0]
trees = [_ for _ in booster]
assert len(trees) == num_boost_round
with pytest.raises(TypeError):
booster["wrong type"]
with pytest.raises(IndexError):
booster[:num_boost_round+1]
with pytest.raises(ValueError):
booster[1, 2] # too many dims
# setitem is not implemented as model is immutable during slicing.
with pytest.raises(TypeError):
booster[...:end] = booster
sliced_0 = booster[1:3]
sliced_1 = booster[3:7]
predt_0 = sliced_0.predict(dtrain, output_margin=True)
predt_1 = sliced_1.predict(dtrain, output_margin=True)
merged = predt_0 + predt_1 - 0.5 # base score.
single = booster[1:7].predict(dtrain, output_margin=True)
np.testing.assert_allclose(merged, single, atol=1e-6)
sliced_0 = booster[1:7:2] # 1,3,5
sliced_1 = booster[2:8:2] # 2,4,6
predt_0 = sliced_0.predict(dtrain, output_margin=True)
predt_1 = sliced_1.predict(dtrain, output_margin=True)
merged = predt_0 + predt_1 - 0.5
single = booster[1:7].predict(dtrain, output_margin=True)
np.testing.assert_allclose(merged, single, atol=1e-6)

View File

@ -113,6 +113,35 @@ class TestCallbacks(unittest.TestCase):
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_early_stopping_save_best_model(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
n_estimators = 100
cls = xgb.XGBClassifier(n_estimators=n_estimators)
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)
cls.fit(X, y, eval_set=[(X, y)],
eval_metric=tm.eval_error_metric, callbacks=[early_stop])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
assert len(dump) == booster.best_iteration
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)
cls = xgb.XGBClassifier(booster='gblinear', n_estimators=10)
self.assertRaises(ValueError, lambda: cls.fit(X, y, eval_set=[(X, y)],
eval_metric=tm.eval_error_metric,
callbacks=[early_stop]))
# No error
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=False)
xgb.XGBClassifier(booster='gblinear', n_estimators=10).fit(
X, y, eval_set=[(X, y)],
eval_metric=tm.eval_error_metric,
callbacks=[early_stop])
def run_eta_decay(self, tree_method, deprecated_callback):
if deprecated_callback:
scheduler = xgb.callback.reset_learning_rate