Tests and documents for new JSON routines. (#5120)
This commit is contained in:
parent
63ffd2f686
commit
27b3646d29
@ -10,6 +10,7 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
|
||||
:caption: Contents:
|
||||
|
||||
model
|
||||
saving_model
|
||||
Distributed XGBoost with AWS YARN <aws_yarn>
|
||||
kubernetes
|
||||
Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
|
||||
|
||||
195
doc/tutorials/saving_model.rst
Normal file
195
doc/tutorials/saving_model.rst
Normal file
@ -0,0 +1,195 @@
|
||||
########################
|
||||
Introduction to Model IO
|
||||
########################
|
||||
|
||||
In XGBoost 1.0.0, we introduced experimental support of using `JSON
|
||||
<https://www.json.org/json-en.html>`_ for saving/loading XGBoost models and related
|
||||
hyper-parameters for training, aiming to replace the old binary internal format with an
|
||||
open format that can be easily reused. The support for binary format will be continued in
|
||||
the future until JSON format is no-longer experimental and has satisfying performance.
|
||||
This tutorial aims to share some basic insights into the JSON serialisation method used in
|
||||
XGBoost. Without explicitly mentioned, the following sections assume you are using the
|
||||
experimental JSON format, which can be enabled by passing
|
||||
``enable_experimental_json_serialization=True`` as training parameter, or provide the file
|
||||
name with ``.json`` as file extension when saving/loading model:
|
||||
``booster.save_model('model.json')``. More details below.
|
||||
|
||||
Before we get started, XGBoost is a gradient boosting library with focus on tree model,
|
||||
which means inside XGBoost, there are 2 distinct parts: the model consisted of trees and
|
||||
algorithms used to build it. If you come from Deep Learning community, then it should be
|
||||
clear to you that there are differences between the neural network structures composed of
|
||||
weights with fixed tensor operations, and the optimizers (like RMSprop) used to train
|
||||
them.
|
||||
|
||||
So when one calls ``booster.save_model``, XGBoost saves the trees, some model parameters
|
||||
like number of input columns in trained trees, and the objective function, which combined
|
||||
to represent the concept of "model" in XGBoost. As for why are we saving the objective as
|
||||
part of model, that's because objective controls transformation of global bias (called
|
||||
``base_score`` in XGBoost). Users can share this model with others for prediction,
|
||||
evaluation or continue the training with a different set of hyper-parameters etc.
|
||||
However, this is not the end of story. There are cases where we need to save something
|
||||
more than just the model itself. For example, in distrbuted training, XGBoost performs
|
||||
checkpointing operation. Or for some reasons, your favorite distributed computing
|
||||
framework decide to copy the model from one worker to another and continue the training in
|
||||
there. In such cases, the serialisation output is required to contain enougth information
|
||||
to continue previous training without user providing any parameters again. We consider
|
||||
such scenario as memory snapshot (or memory based serialisation method) and distinguish it
|
||||
with normal model IO operation. In Python, this can be invoked by pickling the
|
||||
``Booster`` object. Other language bindings are still working in progress.
|
||||
|
||||
.. note::
|
||||
|
||||
The old binary format doesn't distinguish difference between model and raw memory
|
||||
serialisation format, it's a mix of everything, which is part of the reason why we want
|
||||
to replace it with a more robust serialisation method. JVM Package has its own memory
|
||||
based serialisation methods.
|
||||
|
||||
To enable JSON format support for model IO (saving only the trees and objective), provide
|
||||
a filename with ``.json`` as file extension:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
bst.save_model('model_file_name.json')
|
||||
|
||||
While for enabling JSON as memory based serialisation format, pass
|
||||
``enable_experimental_json_serialization`` as a training parameter. In Python this can be
|
||||
done by:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
bst = xgboost.train({'enable_experimental_json_serialization': True}, dtrain)
|
||||
with open('filename', 'wb') as fd:
|
||||
pickle.dump(bst, fd)
|
||||
|
||||
Notice the ``filename`` is for Python intrinsic function ``open``, not for XGBoost. Hence
|
||||
parameter ``enable_experimental_json_serialization`` is required to enable JSON format.
|
||||
As the name suggested, memory based serialisation captures many stuffs internal to
|
||||
XGBoost, so it's only suitable to be used for checkpoints, which doesn't require stable
|
||||
output format. That being said, loading pickled booster (memory snapshot) in a different
|
||||
XGBoost version may lead to errors or undefined behaviors. But we promise the stable
|
||||
output format of binary model and JSON model (once it's no-longer experimental) as they
|
||||
are designed to be reusable. This scheme fits as Python itself doesn't guarantee pickled
|
||||
bytecode can be used in different Python version.
|
||||
|
||||
***************************
|
||||
Custom objective and metric
|
||||
***************************
|
||||
|
||||
XGBoost accepts user provided objective and metric functions as an extension. These
|
||||
functions are not saved in model file as they are language dependent feature. With
|
||||
Python, user can pickle the model to include these functions in saved binary. One
|
||||
drawback is, the output from pickle is not a stable serialization format and doesn't work
|
||||
on different Python version or XGBoost version, not to mention different language
|
||||
environment. Another way to workaround this limitation is to provide these functions
|
||||
again after the model is loaded. If the customized function is useful, please consider
|
||||
making a PR for implementing it inside XGBoost, this way we can have your functions
|
||||
working with different language bindings.
|
||||
|
||||
********************************************************
|
||||
Saving and Loading the internal parameters configuration
|
||||
********************************************************
|
||||
|
||||
XGBoost's ``C API`` and ``Python API`` supports saving and loading the internal
|
||||
configuration directly as a JSON string. In Python package:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
bst = xgboost.train(...)
|
||||
config = bst.save_config()
|
||||
print(config)
|
||||
|
||||
Will print out something similiar to (not actual output as it's too long for demonstration):
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"Learner": {
|
||||
"generic_parameter": {
|
||||
"enable_experimental_json_serialization": "0",
|
||||
"gpu_id": "0",
|
||||
"gpu_page_size": "0",
|
||||
"n_jobs": "0",
|
||||
"random_state": "0",
|
||||
"seed": "0",
|
||||
"seed_per_iteration": "0"
|
||||
},
|
||||
"gradient_booster": {
|
||||
"gbtree_train_param": {
|
||||
"num_parallel_tree": "1",
|
||||
"predictor": "gpu_predictor",
|
||||
"process_type": "default",
|
||||
"tree_method": "gpu_hist",
|
||||
"updater": "grow_gpu_hist",
|
||||
"updater_seq": "grow_gpu_hist"
|
||||
},
|
||||
"name": "gbtree",
|
||||
"updater": {
|
||||
"grow_gpu_hist": {
|
||||
"gpu_hist_train_param": {
|
||||
"debug_synchronize": "0",
|
||||
"gpu_batch_nrows": "0",
|
||||
"single_precision_histogram": "0"
|
||||
},
|
||||
"train_param": {
|
||||
"alpha": "0",
|
||||
"cache_opt": "1",
|
||||
"colsample_bylevel": "1",
|
||||
"colsample_bynode": "1",
|
||||
"colsample_bytree": "1",
|
||||
"default_direction": "learn",
|
||||
"enable_feature_grouping": "0",
|
||||
"eta": "0.300000012",
|
||||
"gamma": "0",
|
||||
"grow_policy": "depthwise",
|
||||
"interaction_constraints": "",
|
||||
"lambda": "1",
|
||||
"learning_rate": "0.300000012",
|
||||
"max_bin": "256",
|
||||
"max_conflict_rate": "0",
|
||||
"max_delta_step": "0",
|
||||
"max_depth": "6",
|
||||
"max_leaves": "0",
|
||||
"max_search_group": "100",
|
||||
"refresh_leaf": "1",
|
||||
"sketch_eps": "0.0299999993",
|
||||
"sketch_ratio": "2",
|
||||
"subsample": "1"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"learner_train_param": {
|
||||
"booster": "gbtree",
|
||||
"disable_default_eval_metric": "0",
|
||||
"dsplit": "auto",
|
||||
"objective": "reg:squarederror"
|
||||
},
|
||||
"metrics": [],
|
||||
"objective": {
|
||||
"name": "reg:squarederror",
|
||||
"reg_loss_param": {
|
||||
"scale_pos_weight": "1"
|
||||
}
|
||||
}
|
||||
},
|
||||
"version": [1, 0, 0]
|
||||
}
|
||||
|
||||
|
||||
You can load it back to the model generated by same version of XGBoost by:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
bst.load_config(config)
|
||||
|
||||
This way users can study the internal representation more closely.
|
||||
|
||||
************
|
||||
Future Plans
|
||||
************
|
||||
|
||||
Right now using the JSON format incurs longer serialisation time, we have been working on
|
||||
optimizing the JSON implementation to close the gap between binary format and JSON format.
|
||||
You can track the progress in `#5046 <https://github.com/dmlc/xgboost/pull/5046>`_.
|
||||
Another important item for JSON format support is a stable and documented `schema
|
||||
<https://json-schema.org/>`_, based on which one can easily reuse the saved model.
|
||||
@ -426,6 +426,24 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
unsigned ntree_limit,
|
||||
bst_ulong *out_len,
|
||||
const float **out_result);
|
||||
/*
|
||||
* Short note for serialization APIs. There are 3 different sets of serialization API.
|
||||
*
|
||||
* - Functions with the term "Model" handles saving/loading XGBoost model like trees or
|
||||
* linear weights. Striping out parameters configuration like training algorithms or
|
||||
* CUDA device ID helps user to reuse the trained model for different tasks, examples
|
||||
* are prediction, training continuation or interpretation.
|
||||
*
|
||||
* - Functions with the term "Config" handles save/loading configuration. It helps user
|
||||
* to study the internal of XGBoost. Also user can use the load method for specifying
|
||||
* paramters in a structured way. These functions are introduced in 1.0.0, and are not
|
||||
* yet stable.
|
||||
*
|
||||
* - Functions with the term "Serialization" are combined of above two. They are used in
|
||||
* situations like check-pointing, or continuing training task in distributed
|
||||
* environment. In these cases the task must be carried out without any user
|
||||
* intervention.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \brief Load model from existing file
|
||||
@ -506,7 +524,10 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Save XGBoost's internal configuration into a JSON document.
|
||||
* \brief Save XGBoost's internal configuration into a JSON document. Currently the
|
||||
* support is experimental, function signature may change in the future without
|
||||
* notice.
|
||||
*
|
||||
* \param handle handle to Booster object.
|
||||
* \param out_str A valid pointer to array of characters. The characters array is
|
||||
* allocated and managed by XGBoost, while pointer to that array needs to
|
||||
@ -516,7 +537,10 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
|
||||
XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, bst_ulong *out_len,
|
||||
char const **out_str);
|
||||
/*!
|
||||
* \brief Load XGBoost's internal configuration from a JSON document.
|
||||
* \brief Load XGBoost's internal configuration from a JSON document. Currently the
|
||||
* support is experimental, function signature may change in the future without
|
||||
* notice.
|
||||
*
|
||||
* \param handle handle to Booster object.
|
||||
* \param json_parameters string representation of a JSON document.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
|
||||
@ -472,7 +472,10 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
// Save model into binary format. The code is about to be deprecated by more robust
|
||||
// JSON serialization format.
|
||||
// JSON serialization format. This function is uneffected by
|
||||
// `enable_experimental_json_serialization` as user might enable this flag for pickle
|
||||
// while still want a binary output. As we are progressing at replacing the binary
|
||||
// format, there's no need to put too much effort on it.
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify
|
||||
std::vector<std::pair<std::string, std::string> > extra_attr;
|
||||
|
||||
640
tests/cpp/test_serialization.cc
Normal file
640
tests/cpp/test_serialization.cc
Normal file
@ -0,0 +1,640 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <string>
|
||||
#include <xgboost/learner.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/base.h>
|
||||
#include "helpers.h"
|
||||
#include "../../src/common/io.h"
|
||||
#include "../../src/common/random.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr<DMatrix> p_dmat) {
|
||||
for (auto& batch : p_dmat->GetBatches<SparsePage>()) {
|
||||
batch.data.HostVector();
|
||||
batch.offset.HostVector();
|
||||
}
|
||||
|
||||
int32_t constexpr kIters = 2;
|
||||
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
std::string const fname = tempdir.path + "/model";
|
||||
|
||||
std::vector<std::string> dumped_0;
|
||||
std::string model_at_kiter;
|
||||
|
||||
{
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||
std::unique_ptr<Learner> learner {Learner::Create({p_dmat})};
|
||||
learner->SetParams(args);
|
||||
for (int32_t iter = 0; iter < kIters; ++iter) {
|
||||
learner->UpdateOneIter(iter, p_dmat.get());
|
||||
}
|
||||
dumped_0 = learner->DumpModel(fmap, true, "json");
|
||||
learner->Save(fo.get());
|
||||
|
||||
common::MemoryBufferStream mem_out(&model_at_kiter);
|
||||
learner->Save(&mem_out);
|
||||
}
|
||||
|
||||
std::vector<std::string> dumped_1;
|
||||
{
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
|
||||
std::unique_ptr<Learner> learner {Learner::Create({p_dmat})};
|
||||
learner->Load(fi.get());
|
||||
learner->Configure();
|
||||
dumped_1 = learner->DumpModel(fmap, true, "json");
|
||||
}
|
||||
ASSERT_EQ(dumped_0, dumped_1);
|
||||
|
||||
std::string model_at_2kiter;
|
||||
|
||||
// Test training continuation with data from host
|
||||
{
|
||||
std::string continued_model;
|
||||
{
|
||||
// Continue the previous training with another kIters
|
||||
std::unique_ptr<dmlc::Stream> fi(
|
||||
dmlc::Stream::Create(fname.c_str(), "r"));
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
|
||||
learner->Load(fi.get());
|
||||
learner->Configure();
|
||||
|
||||
// verify the loaded model doesn't change.
|
||||
std::string serialised_model_tmp;
|
||||
common::MemoryBufferStream mem_out(&serialised_model_tmp);
|
||||
learner->Save(&mem_out);
|
||||
ASSERT_EQ(model_at_kiter, serialised_model_tmp);
|
||||
|
||||
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
|
||||
batch.data.HostVector();
|
||||
batch.offset.HostVector();
|
||||
}
|
||||
|
||||
for (int32_t iter = kIters; iter < 2 * kIters; ++iter) {
|
||||
learner->UpdateOneIter(iter, p_dmat.get());
|
||||
}
|
||||
common::MemoryBufferStream fo(&continued_model);
|
||||
learner->Save(&fo);
|
||||
}
|
||||
|
||||
{
|
||||
// Train 2 * kIters in one go
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
|
||||
learner->SetParams(args);
|
||||
for (int32_t iter = 0; iter < 2 * kIters; ++iter) {
|
||||
learner->UpdateOneIter(iter, p_dmat.get());
|
||||
|
||||
// Verify model is same at the same iteration during two training
|
||||
// sessions.
|
||||
if (iter == kIters - 1) {
|
||||
std::string reproduced_model;
|
||||
common::MemoryBufferStream fo(&reproduced_model);
|
||||
learner->Save(&fo);
|
||||
ASSERT_EQ(model_at_kiter, reproduced_model);
|
||||
}
|
||||
}
|
||||
common::MemoryBufferStream fo(&model_at_2kiter);
|
||||
learner->Save(&fo);
|
||||
}
|
||||
Json m_0 = Json::Load(StringView{continued_model.c_str(), continued_model.size()});
|
||||
Json m_1 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()});
|
||||
ASSERT_EQ(m_0, m_1);
|
||||
}
|
||||
|
||||
// Test training continuation with data from device.
|
||||
{
|
||||
// Continue the previous training but on data from device.
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
|
||||
learner->Load(fi.get());
|
||||
learner->Configure();
|
||||
|
||||
// verify the loaded model doesn't change.
|
||||
std::string serialised_model_tmp;
|
||||
common::MemoryBufferStream mem_out(&serialised_model_tmp);
|
||||
learner->Save(&mem_out);
|
||||
ASSERT_EQ(model_at_kiter, serialised_model_tmp);
|
||||
|
||||
learner->SetParam("gpu_id", "0");
|
||||
// Pull data to device
|
||||
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
|
||||
batch.data.SetDevice(0);
|
||||
batch.data.DeviceSpan();
|
||||
batch.offset.SetDevice(0);
|
||||
batch.offset.DeviceSpan();
|
||||
}
|
||||
|
||||
for (int32_t iter = kIters; iter < 2 * kIters; ++iter) {
|
||||
learner->UpdateOneIter(iter, p_dmat.get());
|
||||
}
|
||||
serialised_model_tmp = std::string{};
|
||||
common::MemoryBufferStream fo(&serialised_model_tmp);
|
||||
learner->Save(&fo);
|
||||
|
||||
Json m_0 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()});
|
||||
Json m_1 = Json::Load(StringView{serialised_model_tmp.c_str(), serialised_model_tmp.size()});
|
||||
// GPU ID is changed as data is coming from device.
|
||||
ASSERT_EQ(get<Object>(m_0["Config"]["learner"]["generic_param"]).erase("gpu_id"),
|
||||
get<Object>(m_1["Config"]["learner"]["generic_param"]).erase("gpu_id"));
|
||||
}
|
||||
}
|
||||
|
||||
// Binary is not tested, as it is NOT reproducible.
|
||||
class SerializationTest : public ::testing::Test {
|
||||
protected:
|
||||
size_t constexpr static kRows = 10;
|
||||
size_t constexpr static kCols = 10;
|
||||
std::shared_ptr<DMatrix>* pp_dmat_;
|
||||
FeatureMap fmap_;
|
||||
|
||||
protected:
|
||||
~SerializationTest() override {
|
||||
delete pp_dmat_;
|
||||
}
|
||||
void SetUp() override {
|
||||
pp_dmat_ = CreateDMatrix(kRows, kCols, .5f);
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{*pp_dmat_};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
auto &h_labels = p_dmat->Info().labels_.HostVector();
|
||||
|
||||
xgboost::SimpleLCG gen(0);
|
||||
SimpleRealUniformDistribution<float> dis(0.0f, 1.0f);
|
||||
|
||||
for (auto& v : h_labels) { v = dis(&gen); }
|
||||
|
||||
for (size_t i = 0; i < kCols; ++i) {
|
||||
std::string name = "feat_" + std::to_string(i);
|
||||
fmap_.PushBack(i, name.c_str(), "q");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(SerializationTest, Exact) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "exact"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"num_parallel_tree", "4"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "exact"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "exact"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(SerializationTest, Approx) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "approx"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"num_parallel_tree", "4"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "approx"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "approx"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(SerializationTest, Hist) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"num_parallel_tree", "4"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(SerializationTest, CPU_CoordDescent) {
|
||||
TestLearnerSerialization({{"booster", "gblinear"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"updater", "coord_descent"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST_F(SerializationTest, GPU_Hist) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"seed", "0"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"tree_method", "gpu_hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"seed", "0"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"num_parallel_tree", "4"},
|
||||
{"tree_method", "gpu_hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"seed", "0"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"tree_method", "gpu_hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(SerializationTest, ConfigurationCount) {
|
||||
auto& p_dmat = *pp_dmat_;
|
||||
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {p_dmat};
|
||||
|
||||
xgboost::ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
|
||||
testing::internal::CaptureStderr();
|
||||
|
||||
std::string model_str;
|
||||
{
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
||||
|
||||
learner->SetParam("tree_method", "gpu_hist");
|
||||
learner->SetParam("enable_experimental_json_serialization", "1");
|
||||
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
learner->UpdateOneIter(i, p_dmat.get());
|
||||
}
|
||||
common::MemoryBufferStream fo(&model_str);
|
||||
learner->Save(&fo);
|
||||
}
|
||||
|
||||
{
|
||||
common::MemoryBufferStream fi(&model_str);
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
||||
learner->Load(&fi);
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
learner->UpdateOneIter(i, p_dmat.get());
|
||||
}
|
||||
}
|
||||
|
||||
std::string output = testing::internal::GetCapturedStderr();
|
||||
std::string target = "[GPU Hist]: Configure";
|
||||
ASSERT_NE(output.find(target), std::string::npos);
|
||||
|
||||
size_t occureences = 0;
|
||||
size_t pos = 0;
|
||||
// Should run configuration exactly 2 times, one for each learner.
|
||||
while ((pos = output.find("[GPU Hist]: Configure", pos)) != std::string::npos) {
|
||||
occureences ++;
|
||||
pos += target.size();
|
||||
}
|
||||
ASSERT_EQ(occureences, 2);
|
||||
|
||||
xgboost::ConsoleLogger::Configure({{"verbosity", "1"}});
|
||||
}
|
||||
|
||||
TEST_F(SerializationTest, GPU_CoordDescent) {
|
||||
TestLearnerSerialization({{"booster", "gblinear"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"updater", "gpu_coord_descent"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
|
||||
class LogitSerializationTest : public SerializationTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
pp_dmat_ = CreateDMatrix(kRows, kCols, .5f);
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{*pp_dmat_};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
auto &h_labels = p_dmat->Info().labels_.HostVector();
|
||||
|
||||
std::bernoulli_distribution flip(0.5);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
rnd.seed(0);
|
||||
|
||||
for (auto& v : h_labels) { v = flip(rnd); }
|
||||
|
||||
for (size_t i = 0; i < kCols; ++i) {
|
||||
std::string name = "feat_" + std::to_string(i);
|
||||
fmap_.PushBack(i, name.c_str(), "q");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(LogitSerializationTest, Exact) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "exact"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "exact"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(LogitSerializationTest, Approx) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "approx"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "approx"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(LogitSerializationTest, Hist) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(LogitSerializationTest, CPU_CoordDescent) {
|
||||
TestLearnerSerialization({{"booster", "gblinear"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"updater", "coord_descent"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST_F(LogitSerializationTest, GPU_Hist) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "gpu_hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"num_parallel_tree", "4"},
|
||||
{"tree_method", "gpu_hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", "2"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "gpu_hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(LogitSerializationTest, GPU_CoordDescent) {
|
||||
TestLearnerSerialization({{"booster", "gblinear"},
|
||||
{"objective", "binary:logistic"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"updater", "gpu_coord_descent"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
class MultiClassesSerializationTest : public SerializationTest {
|
||||
protected:
|
||||
size_t constexpr static kClasses = 4;
|
||||
|
||||
void SetUp() override {
|
||||
pp_dmat_ = CreateDMatrix(kRows, kCols, .5f);
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{*pp_dmat_};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
auto &h_labels = p_dmat->Info().labels_.HostVector();
|
||||
|
||||
std::uniform_int_distribution<size_t> categorical(0, kClasses - 1);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
rnd.seed(0);
|
||||
|
||||
for (auto& v : h_labels) { v = categorical(rnd); }
|
||||
|
||||
for (size_t i = 0; i < kCols; ++i) {
|
||||
std::string name = "feat_" + std::to_string(i);
|
||||
fmap_.PushBack(i, name.c_str(), "q");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(MultiClassesSerializationTest, Exact) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "exact"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"num_parallel_tree", "4"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "exact"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "exact"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(MultiClassesSerializationTest, Approx) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "approx"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "approx"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(MultiClassesSerializationTest, Hist) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"num_parallel_tree", "4"},
|
||||
{"tree_method", "hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(MultiClassesSerializationTest, CPU_CoordDescent) {
|
||||
TestLearnerSerialization({{"booster", "gblinear"},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"updater", "coord_descent"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST_F(MultiClassesSerializationTest, GPU_Hist) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "gpu_hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
// GPU_Hist has higher floating point error. 1e-6 doesn't work
|
||||
// after num_parallel_tree goes to 4
|
||||
{"num_parallel_tree", "3"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "gpu_hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
|
||||
TestLearnerSerialization({{"booster", "dart"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"max_depth", std::to_string(kClasses)},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"tree_method", "gpu_hist"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
|
||||
TEST_F(MultiClassesSerializationTest, GPU_CoordDescent) {
|
||||
TestLearnerSerialization({{"booster", "gblinear"},
|
||||
{"num_class", std::to_string(kClasses)},
|
||||
{"seed", "0"},
|
||||
{"nthread", "1"},
|
||||
{"enable_experimental_json_serialization", "1"},
|
||||
{"updater", "gpu_coord_descent"}},
|
||||
fmap_, *pp_dmat_);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
} // namespace xgboost
|
||||
Loading…
x
Reference in New Issue
Block a user