Implement JSON IO for updaters (#5094)
* Implement JSON IO for updaters. * Remove parameters in split evaluator.
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "../helpers.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "../../../src/data/sparse_page_source.h"
|
||||
#include "../../../src/gbm/gbtree_model.h"
|
||||
#include "../../../src/tree/updater_gpu_hist.cu"
|
||||
@@ -424,5 +425,24 @@ TEST(GpuHist, ExternalMemory) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GpuHist, Config_IO) {
|
||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||
std::unique_ptr<TreeUpdater> updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) };
|
||||
updater->Configure(Args{});
|
||||
|
||||
Json j_updater { Object() };
|
||||
updater->SaveConfig(&j_updater);
|
||||
ASSERT_TRUE(IsA<Object>(j_updater["gpu_hist_train_param"]));
|
||||
ASSERT_TRUE(IsA<Object>(j_updater["train_param"]));
|
||||
updater->LoadConfig(j_updater);
|
||||
|
||||
Json j_updater_roundtrip { Object() };
|
||||
updater->SaveConfig(&j_updater_roundtrip);
|
||||
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["gpu_hist_train_param"]));
|
||||
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["train_param"]));
|
||||
|
||||
ASSERT_EQ(j_updater, j_updater_roundtrip);
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -162,7 +162,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
// Initialize split evaluator
|
||||
std::unique_ptr<SplitEvaluator> evaluator(SplitEvaluator::Create("elastic_net"));
|
||||
evaluator->Init({});
|
||||
evaluator->Init(¶m_);
|
||||
|
||||
// Now enumerate all feature*threshold combination to get best split
|
||||
// To simplify logic, we make some assumptions:
|
||||
@@ -235,6 +235,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
const std::vector<std::pair<std::string, std::string> >& args) :
|
||||
cfg_{args} {
|
||||
QuantileHistMaker::Configure(args);
|
||||
spliteval_->Init(¶m_);
|
||||
builder_.reset(
|
||||
new BuilderMock(
|
||||
param_,
|
||||
|
||||
Reference in New Issue
Block a user