JSON configuration IO. (#5111)
* Add saving/loading JSON configuration. * Implement Python pickle interface with new IO routines. * Basic tests for training continuation.
This commit is contained in:
@@ -1027,8 +1027,6 @@ class GPUHistMakerSpecialised {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
generic_param_ = generic_param;
|
||||
hist_maker_param_.UpdateAllowUnknown(args);
|
||||
device_ = generic_param_->gpu_id;
|
||||
CHECK_GE(device_, 0) << "Must have at least one device";
|
||||
dh::CheckComputeCapability();
|
||||
|
||||
monitor_.Init("updater_gpu_hist");
|
||||
@@ -1041,6 +1039,7 @@ class GPUHistMakerSpecialised {
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
monitor_.StartCuda("Update");
|
||||
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
@@ -1064,6 +1063,8 @@ class GPUHistMakerSpecialised {
|
||||
}
|
||||
|
||||
void InitDataOnce(DMatrix* dmat) {
|
||||
device_ = generic_param_->gpu_id;
|
||||
CHECK_GE(device_, 0) << "Must have at least one device";
|
||||
info_ = &dmat->Info();
|
||||
reducer_.Init({device_});
|
||||
|
||||
@@ -1162,14 +1163,24 @@ class GPUHistMakerSpecialised {
|
||||
class GPUHistMaker : public TreeUpdater {
|
||||
public:
|
||||
void Configure(const Args& args) override {
|
||||
// Used in test to count how many configurations are performed
|
||||
LOG(DEBUG) << "[GPU Hist]: Configure";
|
||||
hist_maker_param_.UpdateAllowUnknown(args);
|
||||
float_maker_.reset();
|
||||
double_maker_.reset();
|
||||
// The passed in args can be empty, if we simply purge the old maker without
|
||||
// preserving parameters then we can't do Update on it.
|
||||
TrainParam param;
|
||||
if (float_maker_) {
|
||||
param = float_maker_->param_;
|
||||
} else if (double_maker_) {
|
||||
param = double_maker_->param_;
|
||||
}
|
||||
if (hist_maker_param_.single_precision_histogram) {
|
||||
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
|
||||
float_maker_->param_ = param;
|
||||
float_maker_->Configure(args, tparam_);
|
||||
} else {
|
||||
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
|
||||
double_maker_->param_ = param;
|
||||
double_maker_->Configure(args, tparam_);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user