Implement JSON IO for updaters (#5094)

* Implement JSON IO for updaters.

* Remove parameters in split evaluator.
This commit is contained in:
Jiaming Yuan
2019-12-07 00:24:00 +08:00
committed by GitHub
parent 2dcb62ddfb
commit 7ef5b78003
14 changed files with 145 additions and 92 deletions

View File

@@ -18,6 +18,7 @@
#include "xgboost/host_device_vector.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
#include "xgboost/json.h"
#include "../common/common.h"
#include "../common/compressed_iterator.h"
@@ -1028,7 +1029,6 @@ class GPUHistMakerSpecialised {
hist_maker_param_.UpdateAllowUnknown(args);
device_ = generic_param_->gpu_id;
CHECK_GE(device_, 0) << "Must have at least one device";
dh::CheckComputeCapability();
monitor_.Init("updater_gpu_hist");
@@ -1129,8 +1129,7 @@ class GPUHistMakerSpecialised {
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
}
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false;
}
@@ -1141,8 +1140,8 @@ class GPUHistMakerSpecialised {
return true;
}
TrainParam param_; // NOLINT
MetaInfo* info_{}; // NOLINT
TrainParam param_; // NOLINT
MetaInfo* info_{}; // NOLINT
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
@@ -1175,6 +1174,27 @@ class GPUHistMaker : public TreeUpdater {
}
}
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
fromJson(config.at("train_param"), &float_maker_->param_);
} else {
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
fromJson(config.at("train_param"), &double_maker_->param_);
}
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["gpu_hist_train_param"] = toJson(hist_maker_param_);
if (hist_maker_param_.single_precision_histogram) {
out["train_param"] = toJson(float_maker_->param_);
} else {
out["train_param"] = toJson(double_maker_->param_);
}
}
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
if (hist_maker_param_.single_precision_histogram) {