Refactor the CLI. (#5574)

* Enable parameter validation.
* Enable JSON.
* Catch `dmlc::Error`.
* Show help message.
Jiaming Yuan 2020-04-26 10:56:33 +08:00, committed by GitHub
parent 7d93932423
commit c90457f489
8 changed files with 432 additions and 218 deletions
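Taken together, the bullets above change the CLI's user-facing contract: unknown parameters are now flagged, `-h`/`-V` are recognized, and a `dmlc::Error` ends the run with a message plus a pointer to the help text instead of an uncaught exception. A minimal sketch of exercising that contract from Python; the binary and config paths are illustrative, not part of this commit:

    import subprocess

    # Illustrative paths; any config written as in the help text's
    # train.conf example works.
    exe = './xgboost'
    completed = subprocess.run([exe, 'train.conf', 'num_round=2'],
                               stdout=subprocess.PIPE)
    # With the refactor, a bad parameter surfaces as
    # "Error running xgboost: ..." and a non-zero exit code.
    print(completed.returncode)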

doc/parameter.rst

@@ -30,11 +30,11 @@ General Parameters
   is displayed as warning message. If there's unexpected behaviour, please try to
   increase value of verbosity.

-* ``validate_parameters`` [default to false, except for Python and R interface]
+* ``validate_parameters`` [default to false, except for Python, R and CLI interface]

   - When set to True, XGBoost will perform validation of input parameters to check whether
     a parameter is used or not. The feature is still experimental. It's expected to have
-    some false positives, especially when used with Scikit-Learn interface.
+    some false positives.

 * ``nthread`` [default to maximum number of threads available if not set]
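The same validator that the CLI now turns on by default can be enabled explicitly from the Python interface; a minimal sketch using the demo dataset referenced elsewhere in this commit:

    import xgboost as xgb

    dtrain = xgb.DMatrix('demo/data/agaricus.txt.train?format=libsvm')
    # With validate_parameters on, a misspelled or unused parameter is
    # reported as a warning rather than silently ignored.
    params = {'objective': 'binary:logistic', 'validate_parameters': True}
    bst = xgb.train(params, dtrain, num_boost_round=2)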

src/cli_main.cc

@@ -4,28 +4,29 @@
  * \brief The command line interface program of xgboost.
  * This file is not included in dynamic library.
  */
-// Copyright 2014 by Contributors
 #define _CRT_SECURE_NO_WARNINGS
 #define _CRT_SECURE_NO_DEPRECATE
 #define NOMINMAX
+#include <dmlc/timer.h>
 #include <xgboost/learner.h>
 #include <xgboost/data.h>
+#include <xgboost/json.h>
 #include <xgboost/logging.h>
-#include <dmlc/timer.h>
+#include <xgboost/parameter.h>
 #include <iomanip>
 #include <ctime>
 #include <string>
 #include <cstdio>
 #include <cstring>
 #include <vector>
-#include "./common/common.h"
-#include "./common/config.h"
+#include "common/common.h"
+#include "common/config.h"
+#include "common/io.h"
+#include "common/version.h"

 namespace xgboost {

 enum CLITask {
   kTrain = 0,
   kDumpModel = 1,
@@ -74,6 +75,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
   /*! \brief all the configurations */
   std::vector<std::pair<std::string, std::string> > cfg;

+  static constexpr char const* const kNull = "NULL";
+
   // declare parameters
   DMLC_DECLARE_PARAMETER(CLIParam) {
     // NOTE: declare everything except eval_data_paths.
@@ -124,15 +127,18 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
   }
   // customized configure function of CLIParam
   inline void Configure(const std::vector<std::pair<std::string, std::string> >& _cfg) {
-    this->cfg = _cfg;
-    this->UpdateAllowUnknown(_cfg);
-    for (const auto& kv : _cfg) {
+    // Don't copy the configuration to enable parameter validation.
+    auto unknown_cfg = this->UpdateAllowUnknown(_cfg);
+    this->cfg.emplace_back("validate_parameters", "True");
+    for (const auto& kv : unknown_cfg) {
       if (!strncmp("eval[", kv.first.c_str(), 5)) {
         char evname[256];
         CHECK_EQ(sscanf(kv.first.c_str(), "eval[%[^]]", evname), 1)
             << "must specify evaluation name for display";
         eval_data_names.emplace_back(evname);
         eval_data_paths.push_back(kv.second);
+      } else {
+        this->cfg.emplace_back(kv);
       }
     }
     // constraint.
@@ -145,221 +151,376 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
   }
 };

+constexpr char const* const CLIParam::kNull;
+
 DMLC_REGISTER_PARAMETER(CLIParam);

-void CLITrain(const CLIParam& param) {
-  const double tstart_data_load = dmlc::GetTime();
-  if (rabit::IsDistributed()) {
-    std::string pname = rabit::GetProcessorName();
-    LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
-  }
-  // load in data.
-  std::shared_ptr<DMatrix> dtrain(
-      DMatrix::Load(
-          param.train_path,
-          ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
-          param.dsplit == 2));
-  std::vector<std::shared_ptr<DMatrix> > deval;
-  std::vector<std::shared_ptr<DMatrix> > cache_mats;
-  std::vector<std::shared_ptr<DMatrix>> eval_datasets;
-  cache_mats.push_back(dtrain);
-  for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
-    deval.emplace_back(
-        std::shared_ptr<DMatrix>(DMatrix::Load(
-            param.eval_data_paths[i],
-            ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
-            param.dsplit == 2)));
-    eval_datasets.push_back(deval.back());
-    cache_mats.push_back(deval.back());
-  }
-  std::vector<std::string> eval_data_names = param.eval_data_names;
-  if (param.eval_train) {
-    eval_datasets.push_back(dtrain);
-    eval_data_names.emplace_back("train");
-  }
-  // initialize the learner.
-  std::unique_ptr<Learner> learner(Learner::Create(cache_mats));
-  int version = rabit::LoadCheckPoint(learner.get());
-  if (version == 0) {
-    // initialize the model if needed.
-    if (param.model_in != "NULL") {
-      std::unique_ptr<dmlc::Stream> fi(
-          dmlc::Stream::Create(param.model_in.c_str(), "r"));
-      learner->LoadModel(fi.get());
-      learner->SetParams(param.cfg);
-    } else {
-      learner->SetParams(param.cfg);
-    }
-  }
-  LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec";
-  // start training.
-  const double start = dmlc::GetTime();
-  for (int i = version / 2; i < param.num_round; ++i) {
-    double elapsed = dmlc::GetTime() - start;
-    if (version % 2 == 0) {
-      LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
-      learner->UpdateOneIter(i, dtrain);
-      if (learner->AllowLazyCheckPoint()) {
-        rabit::LazyCheckPoint(learner.get());
-      } else {
-        rabit::CheckPoint(learner.get());
-      }
-      version += 1;
-    }
-    CHECK_EQ(version, rabit::VersionNumber());
-    std::string res = learner->EvalOneIter(i, eval_datasets, eval_data_names);
-    if (rabit::IsDistributed()) {
-      if (rabit::GetRank() == 0) {
-        LOG(TRACKER) << res;
-      }
-    } else {
-      LOG(CONSOLE) << res;
-    }
-    if (param.save_period != 0 &&
-        (i + 1) % param.save_period == 0 &&
-        rabit::GetRank() == 0) {
-      std::ostringstream os;
-      os << param.model_dir << '/'
-         << std::setfill('0') << std::setw(4)
-         << i + 1 << ".model";
-      std::unique_ptr<dmlc::Stream> fo(
-          dmlc::Stream::Create(os.str().c_str(), "w"));
-      learner->SaveModel(fo.get());
-    }
-    if (learner->AllowLazyCheckPoint()) {
-      rabit::LazyCheckPoint(learner.get());
-    } else {
-      rabit::CheckPoint(learner.get());
-    }
-    version += 1;
-    CHECK_EQ(version, rabit::VersionNumber());
-  }
-  LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start << " sec";
-  // always save final round
-  if ((param.save_period == 0 || param.num_round % param.save_period != 0) &&
-      param.model_out != "NONE" &&
-      rabit::GetRank() == 0) {
-    std::ostringstream os;
-    if (param.model_out == "NULL") {
-      os << param.model_dir << '/'
-         << std::setfill('0') << std::setw(4)
-         << param.num_round << ".model";
-    } else {
-      os << param.model_out;
-    }
-    std::unique_ptr<dmlc::Stream> fo(
-        dmlc::Stream::Create(os.str().c_str(), "w"));
-    learner->SaveModel(fo.get());
-  }
-  double elapsed = dmlc::GetTime() - start;
-  LOG(INFO) << "update end, " << elapsed << " sec in all";
-}
-
-void CLIDumpModel(const CLIParam& param) {
-  FeatureMap fmap;
-  if (param.name_fmap != "NULL") {
-    std::unique_ptr<dmlc::Stream> fs(
-        dmlc::Stream::Create(param.name_fmap.c_str(), "r"));
-    dmlc::istream is(fs.get());
-    fmap.LoadText(is);
-  }
-  // load model
-  CHECK_NE(param.model_in, "NULL")
-      << "Must specify model_in for dump";
-  std::unique_ptr<Learner> learner(Learner::Create({}));
-  std::unique_ptr<dmlc::Stream> fi(
-      dmlc::Stream::Create(param.model_in.c_str(), "r"));
-  learner->SetParams(param.cfg);
-  learner->LoadModel(fi.get());
-  // dump data
-  std::vector<std::string> dump = learner->DumpModel(
-      fmap, param.dump_stats, param.dump_format);
-  std::unique_ptr<dmlc::Stream> fo(
-      dmlc::Stream::Create(param.name_dump.c_str(), "w"));
-  dmlc::ostream os(fo.get());
-  if (param.dump_format == "json") {
-    os << "[" << std::endl;
-    for (size_t i = 0; i < dump.size(); ++i) {
-      if (i != 0) os << "," << std::endl;
-      os << dump[i];  // Dump the previously generated JSON here
-    }
-    os << std::endl << "]" << std::endl;
-  } else {
-    for (size_t i = 0; i < dump.size(); ++i) {
-      os << "booster[" << i << "]:\n";
-      os << dump[i];
-    }
-  }
-  // force flush before fo destruct.
-  os.set_stream(nullptr);
-}
-
-void CLIPredict(const CLIParam& param) {
-  CHECK_NE(param.test_path, "NULL")
-      << "Test dataset parameter test:data must be specified.";
-  // load data
-  std::shared_ptr<DMatrix> dtest(
-      DMatrix::Load(
-          param.test_path,
-          ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
-          param.dsplit == 2));
-  // load model
-  CHECK_NE(param.model_in, "NULL")
-      << "Must specify model_in for predict";
-  std::unique_ptr<Learner> learner(Learner::Create({}));
-  std::unique_ptr<dmlc::Stream> fi(
-      dmlc::Stream::Create(param.model_in.c_str(), "r"));
-  learner->LoadModel(fi.get());
-  learner->SetParams(param.cfg);
-
-  LOG(INFO) << "start prediction...";
-  HostDeviceVector<bst_float> preds;
-  learner->Predict(dtest, param.pred_margin, &preds, param.ntree_limit);
-  LOG(CONSOLE) << "writing prediction to " << param.name_pred;
-
-  std::unique_ptr<dmlc::Stream> fo(
-      dmlc::Stream::Create(param.name_pred.c_str(), "w"));
-  dmlc::ostream os(fo.get());
-  for (bst_float p : preds.ConstHostVector()) {
-    os << std::setprecision(std::numeric_limits<bst_float>::max_digits10)
-       << p << '\n';
-  }
-  // force flush before fo destruct.
-  os.set_stream(nullptr);
-}
-
-int CLIRunTask(int argc, char *argv[]) {
-  if (argc < 2) {
-    printf("Usage: <config>\n");
-    return 0;
-  }
-  rabit::Init(argc, argv);
-
-  common::ConfigParser cp(argv[1]);
-  auto cfg = cp.Parse();
-
-  for (int i = 2; i < argc; ++i) {
-    char name[256], val[256];
-    if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
-      cfg.emplace_back(std::string(name), std::string(val));
-    }
-  }
-
-  CLIParam param;
-  param.Configure(cfg);
-  switch (param.task) {
-    case kTrain: CLITrain(param); break;
-    case kDumpModel: CLIDumpModel(param); break;
-    case kPredict: CLIPredict(param); break;
-  }
-  rabit::Finalize();
-  return 0;
-}
+std::string CliHelp() {
+  return "Use xgboost -h for showing help information.\n";
+}
+
+void CLIError(dmlc::Error const& e) {
+  std::cerr << "Error running xgboost:\n\n"
+            << e.what() << "\n"
+            << CliHelp()
+            << std::endl;
+}
+
+class CLI {
+  CLIParam param_;
+  std::unique_ptr<Learner> learner_;
+  enum Print {
+    kNone,
+    kVersion,
+    kHelp
+  } print_info_ {kNone};
+
+  int ResetLearner(std::vector<std::shared_ptr<DMatrix>> const& matrices) {
+    learner_.reset(Learner::Create(matrices));
+    int version = rabit::LoadCheckPoint(learner_.get());
+    if (version == 0) {
+      if (param_.model_in != CLIParam::kNull) {
+        this->LoadModel(param_.model_in, learner_.get());
+        learner_->SetParams(param_.cfg);
+      } else {
+        learner_->SetParams(param_.cfg);
+      }
+    }
+    learner_->Configure();
+    return version;
+  }
+
+  void CLITrain() {
+    const double tstart_data_load = dmlc::GetTime();
+    if (rabit::IsDistributed()) {
+      std::string pname = rabit::GetProcessorName();
+      LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
+    }
+    // load in data.
+    std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
+        param_.train_path,
+        ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
+        param_.dsplit == 2));
+    std::vector<std::shared_ptr<DMatrix>> deval;
+    std::vector<std::shared_ptr<DMatrix>> cache_mats;
+    std::vector<std::shared_ptr<DMatrix>> eval_datasets;
+    cache_mats.push_back(dtrain);
+    for (size_t i = 0; i < param_.eval_data_names.size(); ++i) {
+      deval.emplace_back(std::shared_ptr<DMatrix>(DMatrix::Load(
+          param_.eval_data_paths[i],
+          ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
+          param_.dsplit == 2)));
+      eval_datasets.push_back(deval.back());
+      cache_mats.push_back(deval.back());
+    }
+    std::vector<std::string> eval_data_names = param_.eval_data_names;
+    if (param_.eval_train) {
+      eval_datasets.push_back(dtrain);
+      eval_data_names.emplace_back("train");
+    }
+    // initialize the learner.
+    int32_t version = this->ResetLearner(cache_mats);
+    LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load
+              << " sec";
+
+    // start training.
+    const double start = dmlc::GetTime();
+    for (int i = version / 2; i < param_.num_round; ++i) {
+      double elapsed = dmlc::GetTime() - start;
+      if (version % 2 == 0) {
+        LOG(INFO) << "boosting round " << i << ", " << elapsed
+                  << " sec elapsed";
+        learner_->UpdateOneIter(i, dtrain);
+        if (learner_->AllowLazyCheckPoint()) {
+          rabit::LazyCheckPoint(learner_.get());
+        } else {
+          rabit::CheckPoint(learner_.get());
+        }
+        version += 1;
+      }
+      CHECK_EQ(version, rabit::VersionNumber());
+      std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
+      if (rabit::IsDistributed()) {
+        if (rabit::GetRank() == 0) {
+          LOG(TRACKER) << res;
+        }
+      } else {
+        LOG(CONSOLE) << res;
+      }
+      if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
+          rabit::GetRank() == 0) {
+        std::ostringstream os;
+        os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
+           << i + 1 << ".model";
+        this->SaveModel(os.str(), learner_.get());
+      }
+      if (learner_->AllowLazyCheckPoint()) {
+        rabit::LazyCheckPoint(learner_.get());
+      } else {
+        rabit::CheckPoint(learner_.get());
+      }
+      version += 1;
+      CHECK_EQ(version, rabit::VersionNumber());
+    }
+    LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start
+              << " sec";
+    // always save final round
+    if ((param_.save_period == 0 ||
+         param_.num_round % param_.save_period != 0) &&
+        param_.model_out != CLIParam::kNull && rabit::GetRank() == 0) {
+      std::ostringstream os;
+      if (param_.model_out == CLIParam::kNull) {
+        os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
+           << param_.num_round << ".model";
+      } else {
+        os << param_.model_out;
+      }
+      this->SaveModel(os.str(), learner_.get());
+    }
+
+    double elapsed = dmlc::GetTime() - start;
+    LOG(INFO) << "update end, " << elapsed << " sec in all";
+  }
+
+  void CLIDumpModel() {
+    FeatureMap fmap;
+    if (param_.name_fmap != CLIParam::kNull) {
+      std::unique_ptr<dmlc::Stream> fs(
+          dmlc::Stream::Create(param_.name_fmap.c_str(), "r"));
+      dmlc::istream is(fs.get());
+      fmap.LoadText(is);
+    }
+    // load model
+    CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for dump";
+    this->ResetLearner({});
+
+    // dump data
+    std::vector<std::string> dump =
+        learner_->DumpModel(fmap, param_.dump_stats, param_.dump_format);
+    std::unique_ptr<dmlc::Stream> fo(
+        dmlc::Stream::Create(param_.name_dump.c_str(), "w"));
+    dmlc::ostream os(fo.get());
+    if (param_.dump_format == "json") {
+      os << "[" << std::endl;
+      for (size_t i = 0; i < dump.size(); ++i) {
+        if (i != 0) {
+          os << "," << std::endl;
+        }
+        os << dump[i];  // Dump the previously generated JSON here
+      }
+      os << std::endl << "]" << std::endl;
+    } else {
+      for (size_t i = 0; i < dump.size(); ++i) {
+        os << "booster[" << i << "]:\n";
+        os << dump[i];
+      }
+    }
+    // force flush before fo destruct.
+    os.set_stream(nullptr);
+  }
+
+  void CLIPredict() {
+    CHECK_NE(param_.test_path, CLIParam::kNull)
+        << "Test dataset parameter test:data must be specified.";
+    // load data
+    std::shared_ptr<DMatrix> dtest(DMatrix::Load(
+        param_.test_path,
+        ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
+        param_.dsplit == 2));
+    // load model
+    CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict";
+    this->ResetLearner({});
+
+    LOG(INFO) << "Start prediction...";
+    HostDeviceVector<bst_float> preds;
+    learner_->Predict(dtest, param_.pred_margin, &preds, param_.ntree_limit);
+    LOG(CONSOLE) << "Writing prediction to " << param_.name_pred;
+
+    std::unique_ptr<dmlc::Stream> fo(
+        dmlc::Stream::Create(param_.name_pred.c_str(), "w"));
+    dmlc::ostream os(fo.get());
+    for (bst_float p : preds.ConstHostVector()) {
+      os << std::setprecision(std::numeric_limits<bst_float>::max_digits10) << p
+         << '\n';
+    }
+    // force flush before fo destruct.
+    os.set_stream(nullptr);
+  }
+
+  void LoadModel(std::string const& path, Learner* learner) const {
+    if (common::FileExtension(path) == "json") {
+      auto str = common::LoadSequentialFile(path);
+      CHECK_GT(str.size(), 2);
+      CHECK_EQ(str[0], '{');
+      Json in{Json::Load({str.c_str(), str.size()})};
+      learner->LoadModel(in);
+    } else {
+      std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path.c_str(), "r"));
+      learner->LoadModel(fi.get());
+    }
+  }
+
+  void SaveModel(std::string const& path, Learner* learner) const {
+    learner->Configure();
+    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path.c_str(), "w"));
+    if (common::FileExtension(path) == "json") {
+      Json out{Object()};
+      learner->SaveModel(&out);
+      std::string str;
+      Json::Dump(out, &str);
+      fo->Write(str.c_str(), str.size());
+    } else {
+      learner->SaveModel(fo.get());
+    }
+  }
+
+  void PrintHelp() const {
+    std::cout << "Usage: xgboost [ -h ] [ -V ] [ config file ] [ arguments ]" << std::endl;
+    std::stringstream ss;
+    ss << R"(
+  Options and arguments:
+    -h, --help
+      Print this message.
+    -V, --version
+      Print XGBoost version.
+    arguments
+      Extra parameters that are not specified in config file, see below.
+
+  Config file specifies the configuration for both training and testing.  Each
+  line contains an [attribute] = [value] configuration.
+
+  General XGBoost parameters:
+    https://xgboost.readthedocs.io/en/latest/parameter.html
+
+  Command line interface specific parameters:
+)";
+
+    std::string help = param_.__DOC__();
+    auto splited = common::Split(help, '\n');
+    for (auto str : splited) {
+      ss << "  " << str << '\n';
+    }
+
+    ss << R"(  eval[NAME]: string, optional, default='NULL'
+      Path to evaluation data, with NAME as data name.
+)";
+    ss << R"(
+  Example:  train.conf
+
+    # General parameters
+    booster = gbtree
+    objective = reg:squarederror
+    eta = 1.0
+    gamma = 1.0
+    seed = 0
+    min_child_weight = 0
+    max_depth = 3
+
+    # Training arguments for CLI.
+    num_round = 2
+    save_period = 0
+    data = "demo/data/agaricus.txt.train?format=libsvm"
+    eval[test] = "demo/data/agaricus.txt.test?format=libsvm"
+
+  See demo/ directory in XGBoost for more examples.
+)";
+    std::cout << ss.str() << std::endl;
+  }
+
+  void PrintVersion() const {
+    auto ver = Version::String(Version::Self());
+    std::cout << "XGBoost: " << ver << std::endl;
+  }
+
+ public:
+  CLI(int argc, char* argv[]) {
+    if (argc < 2) {
+      this->PrintHelp();
+      exit(1);
+    }
+    for (int i = 0; i < argc; ++i) {
+      std::string str {argv[i]};
+      if (str == "-h" || str == "--help") {
+        print_info_ = kHelp;
+        break;
+      } else if (str == "-V" || str == "--version") {
+        print_info_ = kVersion;
+        break;
+      }
+    }
+    if (print_info_ != kNone) {
+      return;
+    }
+
+    rabit::Init(argc, argv);
+    std::string config_path = argv[1];
+
+    common::ConfigParser cp(config_path);
+    auto cfg = cp.Parse();
+
+    for (int i = 2; i < argc; ++i) {
+      char name[256], val[256];
+      if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
+        cfg.emplace_back(std::string(name), std::string(val));
+      }
+    }
+
+    param_.Configure(cfg);
+  }
+
+  int Run() {
+    switch (this->print_info_) {
+      case kNone:
+        break;
+      case kVersion: {
+        this->PrintVersion();
+        return 0;
+      }
+      case kHelp: {
+        this->PrintHelp();
+        return 0;
+      }
+    }
+    try {
+      switch (param_.task) {
+        case kTrain:
+          CLITrain();
+          break;
+        case kDumpModel:
+          CLIDumpModel();
+          break;
+        case kPredict:
+          CLIPredict();
+          break;
+      }
+    } catch (dmlc::Error const& e) {
+      xgboost::CLIError(e);
+      return 1;
+    }
+    return 0;
+  }
+
+  ~CLI() {
+    rabit::Finalize();
+  }
+};

 }  // namespace xgboost

 int main(int argc, char *argv[]) {
-  return xgboost::CLIRunTask(argc, argv);
+  try {
+    xgboost::CLI cli(argc, argv);
+    return cli.Run();
+  } catch (dmlc::Error const& e) {
+    // This captures only the initialization error.
+    xgboost::CLIError(e);
+    return 1;
+  }
+  return 0;
 }
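Note the extension dispatch in LoadModel/SaveModel above: a model path ending in .json now round-trips through the JSON serializer (the "Enable JSON" bullet in the commit message). A sketch of driving that end to end from Python; the config and file names are illustrative:

    import json
    import subprocess

    # 'train.conf' is assumed to follow the example shown in PrintHelp().
    subprocess.run(['./xgboost', 'train.conf', 'model_out=0002.json'])

    with open('0002.json') as fd:
        model = json.load(fd)
    # The JSON document exposes the learner structure, as the new
    # test_cli_model_json below asserts.
    print(model['learner']['gradient_booster']['name'])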

src/common/config.h

@@ -7,8 +7,6 @@
 #ifndef XGBOOST_COMMON_CONFIG_H_
 #define XGBOOST_COMMON_CONFIG_H_

-#include <xgboost/logging.h>
-#include <cstdio>
 #include <string>
 #include <fstream>
 #include <istream>
@@ -18,6 +16,8 @@
 #include <iterator>
 #include <utility>

+#include "xgboost/logging.h"
+
 namespace xgboost {
 namespace common {

 /*!
@@ -40,10 +40,16 @@ class ConfigParser {
   std::string LoadConfigFile(const std::string& path) {
     std::ifstream fin(path, std::ios_base::in | std::ios_base::binary);
-    CHECK(fin) << "Failed to open: " << path;
-    std::string content{std::istreambuf_iterator<char>(fin),
-                        std::istreambuf_iterator<char>()};
-    return content;
+    CHECK(fin) << "Failed to open config file: \"" << path << "\"";
+    try {
+      std::string content{std::istreambuf_iterator<char>(fin),
+                          std::istreambuf_iterator<char>()};
+      return content;
+    } catch (std::ios_base::failure const& e) {
+      LOG(FATAL) << "Failed to read config file: \"" << path << "\"\n"
+                 << e.what();
+    }
+    return "";
   }
/*!

src/common/device_helpers.cuh

@@ -30,7 +30,6 @@
 #ifdef XGBOOST_USE_NCCL
 #include "nccl.h"
-#include "../common/io.h"
 #endif

 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)

src/common/random.h

@@ -18,7 +18,6 @@
 #include <random>

 #include "xgboost/host_device_vector.h"
-#include "io.h"

 namespace xgboost {
 namespace common {

src/objective/aft_obj.cu

@@ -114,5 +114,3 @@ XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft")
 }  // namespace obj
 }  // namespace xgboost

src/tree/updater_gpu_hist.cu

@@ -17,10 +17,12 @@
 #include "xgboost/span.h"
 #include "xgboost/json.h"

+#include "../common/io.h"
 #include "../common/device_helpers.cuh"
 #include "../common/hist_util.h"
 #include "../common/timer.h"
 #include "../data/ellpack_page.cuh"
 #include "param.h"
 #include "updater_gpu_common.cuh"
 #include "constraints.cuh"

tests/python/test_cli.py

@@ -5,6 +5,7 @@ import platform
 import xgboost
 import subprocess
 import numpy
+import json


 class TestCLI(unittest.TestCase):
@@ -27,20 +28,23 @@ data = {data_path}
 eval[test] = {data_path}
 '''

-    def test_cli_model(self):
-        curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
-        project_root = os.path.normpath(
-            os.path.join(curdir, os.path.pardir, os.path.pardir))
-        data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
-            root=project_root)
+    curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
+    project_root = os.path.normpath(
+        os.path.join(curdir, os.path.pardir, os.path.pardir))

+    def get_exe(self):
         if platform.system() == 'Windows':
             exe = 'xgboost.exe'
         else:
             exe = 'xgboost'
-        exe = os.path.join(project_root, exe)
+        exe = os.path.join(self.project_root, exe)
         assert os.path.exists(exe)
+        return exe

+    def test_cli_model(self):
+        data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
+            root=self.project_root)
+        exe = self.get_exe()
         seed = 1994

         with tempfile.TemporaryDirectory() as tmpdir:
@@ -102,3 +106,48 @@ eval[test] = {data_path}
             py_model_bin = fd.read()

         assert hash(cli_model_bin) == hash(py_model_bin)
+
+    def test_cli_help(self):
+        exe = self.get_exe()
+        completed = subprocess.run([exe], stdout=subprocess.PIPE)
+        error_msg = completed.stdout.decode('utf-8')
+        ret = completed.returncode
+        assert ret == 1
+        assert error_msg.find('Usage') != -1
+        assert error_msg.find('eval[NAME]') != -1
+
+        completed = subprocess.run([exe, '-V'], stdout=subprocess.PIPE)
+        msg = completed.stdout.decode('utf-8')
+        assert msg.find('XGBoost') != -1
+        v = xgboost.__version__
+        if v.find('SNAPSHOT') != -1:
+            assert msg.split(':')[1].strip() == v.split('-')[0]
+        else:
+            assert msg.split(':')[1].strip() == v
+
+    def test_cli_model_json(self):
+        exe = self.get_exe()
+        data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
+            root=self.project_root)
+        seed = 1994
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model_out_cli = os.path.join(
+                tmpdir, 'test_load_cli_model-cli.json')
+            config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')
+            train_conf = self.template.format(data_path=data_path,
+                                              seed=seed,
+                                              task='train',
+                                              model_in='NULL',
+                                              model_out=model_out_cli,
+                                              test_path='NULL',
+                                              name_pred='NULL')
+            with open(config_path, 'w') as fd:
+                fd.write(train_conf)
+
+            subprocess.run([exe, config_path])
+
+            with open(model_out_cli, 'r') as fd:
+                model = json.load(fd)
+            assert model['learner']['gradient_booster']['name'] == 'gbtree'
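A natural follow-on, not part of the test above, is consuming the CLI-produced JSON model from the Python wrapper; a minimal sketch assuming the same file layout as test_cli_model_json:

    import xgboost

    # Path as written by the CLI in test_cli_model_json above.
    model_out_cli = 'test_load_cli_model-cli.json'
    bst = xgboost.Booster(model_file=model_out_cli)
    # The booster is immediately usable, e.g. for dumping the trees.
    print(len(bst.get_dump()))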