Implement JSON IO for updaters (#5094)
* Implement JSON IO for updaters. * Remove parameters in split evaluator.
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/compressed_iterator.h"
|
||||
@@ -1028,7 +1029,6 @@ class GPUHistMakerSpecialised {
|
||||
hist_maker_param_.UpdateAllowUnknown(args);
|
||||
device_ = generic_param_->gpu_id;
|
||||
CHECK_GE(device_, 0) << "Must have at least one device";
|
||||
|
||||
dh::CheckComputeCapability();
|
||||
|
||||
monitor_.Init("updater_gpu_hist");
|
||||
@@ -1129,8 +1129,7 @@ class GPUHistMakerSpecialised {
|
||||
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
@@ -1141,8 +1140,8 @@ class GPUHistMakerSpecialised {
|
||||
return true;
|
||||
}
|
||||
|
||||
TrainParam param_; // NOLINT
|
||||
MetaInfo* info_{}; // NOLINT
|
||||
TrainParam param_; // NOLINT
|
||||
MetaInfo* info_{}; // NOLINT
|
||||
|
||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||
|
||||
@@ -1175,6 +1174,27 @@ class GPUHistMaker : public TreeUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
|
||||
fromJson(config.at("train_param"), &float_maker_->param_);
|
||||
} else {
|
||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
|
||||
fromJson(config.at("train_param"), &double_maker_->param_);
|
||||
}
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["gpu_hist_train_param"] = toJson(hist_maker_param_);
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
out["train_param"] = toJson(float_maker_->param_);
|
||||
} else {
|
||||
out["train_param"] = toJson(double_maker_->param_);
|
||||
}
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
|
||||
Reference in New Issue
Block a user