Fix CLI model IO. (#5535)

* Add test for comparing Python and CLI training result.
This commit is contained in:
Jiaming Yuan
2020-04-16 07:48:47 +08:00
committed by GitHub
parent 0676a19e70
commit 468b1594d3
4 changed files with 100 additions and 9 deletions

View File

@@ -138,14 +138,10 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
// constraint.
if (name_pred == "stdout") {
save_period = 0;
this->cfg.emplace_back(std::make_pair("silent", "0"));
}
if (dsplit == 0 && rabit::IsDistributed()) {
dsplit = 2;
}
if (rabit::GetRank() != 0) {
this->cfg.emplace_back(std::make_pair("silent", "1"));
}
}
};
@@ -189,7 +185,7 @@ void CLITrain(const CLIParam& param) {
if (param.model_in != "NULL") {
std::unique_ptr<dmlc::Stream> fi(
dmlc::Stream::Create(param.model_in.c_str(), "r"));
learner->Load(fi.get());
learner->LoadModel(fi.get());
learner->SetParams(param.cfg);
} else {
learner->SetParams(param.cfg);
@@ -229,7 +225,7 @@ void CLITrain(const CLIParam& param) {
<< i + 1 << ".model";
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(os.str().c_str(), "w"));
learner->Save(fo.get());
learner->SaveModel(fo.get());
}
if (learner->AllowLazyCheckPoint()) {
@@ -255,7 +251,7 @@ void CLITrain(const CLIParam& param) {
}
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(os.str().c_str(), "w"));
learner->Save(fo.get());
learner->SaveModel(fo.get());
}
double elapsed = dmlc::GetTime() - start;
@@ -277,7 +273,7 @@ void CLIDumpModel(const CLIParam& param) {
std::unique_ptr<dmlc::Stream> fi(
dmlc::Stream::Create(param.model_in.c_str(), "r"));
learner->SetParams(param.cfg);
learner->Load(fi.get());
learner->LoadModel(fi.get());
// dump data
std::vector<std::string> dump = learner->DumpModel(
fmap, param.dump_stats, param.dump_format);
@@ -316,7 +312,7 @@ void CLIPredict(const CLIParam& param) {
std::unique_ptr<Learner> learner(Learner::Create({}));
std::unique_ptr<dmlc::Stream> fi(
dmlc::Stream::Create(param.model_in.c_str(), "r"));
learner->Load(fi.get());
learner->LoadModel(fi.get());
learner->SetParams(param.cfg);
LOG(INFO) << "start prediction...";