Implement JSON IO for updaters (#5094)

* Implement JSON IO for updaters.

* Remove parameters in split evaluator.
This commit is contained in:
Jiaming Yuan
2019-12-07 00:24:00 +08:00
committed by GitHub
parent 2dcb62ddfb
commit 7ef5b78003
14 changed files with 145 additions and 92 deletions

View File

@@ -11,6 +11,7 @@
#include "../helpers.h"
#include "gtest/gtest.h"
#include "xgboost/json.h"
#include "../../../src/data/sparse_page_source.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../../../src/tree/updater_gpu_hist.cu"
@@ -424,5 +425,24 @@ TEST(GpuHist, ExternalMemory) {
}
}
TEST(GpuHist, Config_IO) {
GenericParameter generic_param(CreateEmptyGenericParam(0));
std::unique_ptr<TreeUpdater> updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) };
updater->Configure(Args{});
Json j_updater { Object() };
updater->SaveConfig(&j_updater);
ASSERT_TRUE(IsA<Object>(j_updater["gpu_hist_train_param"]));
ASSERT_TRUE(IsA<Object>(j_updater["train_param"]));
updater->LoadConfig(j_updater);
Json j_updater_roundtrip { Object() };
updater->SaveConfig(&j_updater_roundtrip);
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["gpu_hist_train_param"]));
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["train_param"]));
ASSERT_EQ(j_updater, j_updater_roundtrip);
}
} // namespace tree
} // namespace xgboost

View File

@@ -162,7 +162,7 @@ class QuantileHistMock : public QuantileHistMaker {
}
// Initialize split evaluator
std::unique_ptr<SplitEvaluator> evaluator(SplitEvaluator::Create("elastic_net"));
evaluator->Init({});
evaluator->Init(&param_);
// Now enumerate all feature*threshold combination to get best split
// To simplify logic, we make some assumptions:
@@ -235,6 +235,7 @@ class QuantileHistMock : public QuantileHistMaker {
const std::vector<std::pair<std::string, std::string> >& args) :
cfg_{args} {
QuantileHistMaker::Configure(args);
spliteval_->Init(&param_);
builder_.reset(
new BuilderMock(
param_,