Fix CLI model IO. (#5535)
* Add test for comparing Python and CLI training result.
This commit is contained in:
@@ -138,14 +138,10 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
|
||||
// constraint.
|
||||
if (name_pred == "stdout") {
|
||||
save_period = 0;
|
||||
this->cfg.emplace_back(std::make_pair("silent", "0"));
|
||||
}
|
||||
if (dsplit == 0 && rabit::IsDistributed()) {
|
||||
dsplit = 2;
|
||||
}
|
||||
if (rabit::GetRank() != 0) {
|
||||
this->cfg.emplace_back(std::make_pair("silent", "1"));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -189,7 +185,7 @@ void CLITrain(const CLIParam& param) {
|
||||
if (param.model_in != "NULL") {
|
||||
std::unique_ptr<dmlc::Stream> fi(
|
||||
dmlc::Stream::Create(param.model_in.c_str(), "r"));
|
||||
learner->Load(fi.get());
|
||||
learner->LoadModel(fi.get());
|
||||
learner->SetParams(param.cfg);
|
||||
} else {
|
||||
learner->SetParams(param.cfg);
|
||||
@@ -229,7 +225,7 @@ void CLITrain(const CLIParam& param) {
|
||||
<< i + 1 << ".model";
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(os.str().c_str(), "w"));
|
||||
learner->Save(fo.get());
|
||||
learner->SaveModel(fo.get());
|
||||
}
|
||||
|
||||
if (learner->AllowLazyCheckPoint()) {
|
||||
@@ -255,7 +251,7 @@ void CLITrain(const CLIParam& param) {
|
||||
}
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(os.str().c_str(), "w"));
|
||||
learner->Save(fo.get());
|
||||
learner->SaveModel(fo.get());
|
||||
}
|
||||
|
||||
double elapsed = dmlc::GetTime() - start;
|
||||
@@ -277,7 +273,7 @@ void CLIDumpModel(const CLIParam& param) {
|
||||
std::unique_ptr<dmlc::Stream> fi(
|
||||
dmlc::Stream::Create(param.model_in.c_str(), "r"));
|
||||
learner->SetParams(param.cfg);
|
||||
learner->Load(fi.get());
|
||||
learner->LoadModel(fi.get());
|
||||
// dump data
|
||||
std::vector<std::string> dump = learner->DumpModel(
|
||||
fmap, param.dump_stats, param.dump_format);
|
||||
@@ -316,7 +312,7 @@ void CLIPredict(const CLIParam& param) {
|
||||
std::unique_ptr<Learner> learner(Learner::Create({}));
|
||||
std::unique_ptr<dmlc::Stream> fi(
|
||||
dmlc::Stream::Create(param.model_in.c_str(), "r"));
|
||||
learner->Load(fi.get());
|
||||
learner->LoadModel(fi.get());
|
||||
learner->SetParams(param.cfg);
|
||||
|
||||
LOG(INFO) << "start prediction...";
|
||||
|
||||
Reference in New Issue
Block a user