JSON configuration IO. (#5111)

* Add saving/loading JSON configuration.
* Implement Python pickle interface with new IO routines.
* Basic tests for training continuation.
This commit is contained in:
Jiaming Yuan
2019-12-15 17:31:53 +08:00
committed by GitHub
parent 5aa007d7b2
commit 3136185bc5
24 changed files with 761 additions and 390 deletions

View File

@@ -682,13 +682,13 @@ void RegTree::LoadModel(Json const& in) {
s.leaf_child_cnt = get<Integer const>(leaf_child_counts[i]);
auto& n = nodes_[i];
auto left = get<Integer const>(lefts[i]);
auto right = get<Integer const>(rights[i]);
auto parent = get<Integer const>(parents[i]);
auto ind = get<Integer const>(indices[i]);
auto cond = get<Number const>(conds[i]);
auto dft_left = get<Boolean const>(default_left[i]);
n = Node(left, right, parent, ind, cond, dft_left);
bst_node_t left = get<Integer const>(lefts[i]);
bst_node_t right = get<Integer const>(rights[i]);
bst_node_t parent = get<Integer const>(parents[i]);
bst_feature_t ind = get<Integer const>(indices[i]);
float cond { get<Number const>(conds[i]) };
bool dft_left { get<Boolean const>(default_left[i]) };
n = Node{left, right, parent, ind, cond, dft_left};
}

View File

@@ -1027,8 +1027,6 @@ class GPUHistMakerSpecialised {
param_.UpdateAllowUnknown(args);
generic_param_ = generic_param;
hist_maker_param_.UpdateAllowUnknown(args);
device_ = generic_param_->gpu_id;
CHECK_GE(device_, 0) << "Must have at least one device";
dh::CheckComputeCapability();
monitor_.Init("updater_gpu_hist");
@@ -1041,6 +1039,7 @@ class GPUHistMakerSpecialised {
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
monitor_.StartCuda("Update");
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
@@ -1064,6 +1063,8 @@ class GPUHistMakerSpecialised {
}
void InitDataOnce(DMatrix* dmat) {
device_ = generic_param_->gpu_id;
CHECK_GE(device_, 0) << "Must have at least one device";
info_ = &dmat->Info();
reducer_.Init({device_});
@@ -1162,14 +1163,24 @@ class GPUHistMakerSpecialised {
class GPUHistMaker : public TreeUpdater {
public:
void Configure(const Args& args) override {
// Used in test to count how many configurations are performed
LOG(DEBUG) << "[GPU Hist]: Configure";
hist_maker_param_.UpdateAllowUnknown(args);
float_maker_.reset();
double_maker_.reset();
// The passed in args can be empty, if we simply purge the old maker without
// preserving parameters then we can't do Update on it.
TrainParam param;
if (float_maker_) {
param = float_maker_->param_;
} else if (double_maker_) {
param = double_maker_->param_;
}
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
float_maker_->param_ = param;
float_maker_->Configure(args, tparam_);
} else {
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
double_maker_->param_ = param;
double_maker_->Configure(args, tparam_);
}
}