Refactor the CLI. (#5574)

* Enable parameter validation.
* Enable JSON.
* Catch `dmlc::Error`.
* Show help message.
Jiaming Yuan 2020-04-26 10:56:33 +08:00 committed by GitHub
parent 7d93932423
commit c90457f489
8 changed files with 432 additions and 218 deletions
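The intended command-line behaviour is easiest to see from the outside. Below is a minimal sketch in the spirit of the new TestCLI cases further down; the `./xgboost` binary path is an assumption about the local build:

    import subprocess

    exe = './xgboost'  # assumed location of the freshly built CLI binary

    # Running with no config file now prints the usage text and exits with 1.
    no_args = subprocess.run([exe], stdout=subprocess.PIPE)
    assert no_args.returncode == 1
    assert 'Usage' in no_args.stdout.decode('utf-8')

    # -V / --version prints the version string, e.g. "XGBoost: 1.1.0".
    version = subprocess.run([exe, '-V'], stdout=subprocess.PIPE)
    assert 'XGBoost' in version.stdout.decode('utf-8')

    # A dmlc::Error (here: a missing config file) is caught instead of
    # terminating the process abruptly; the message ends with a hint to run
    # `xgboost -h`, and the exit code is 1.
    bad = subprocess.run([exe, 'no_such_file.conf'], stderr=subprocess.PIPE)
    assert bad.returncode == 1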

View File

@@ -30,11 +30,11 @@ General Parameters
   is displayed as warning message. If there's unexpected behaviour, please try to
   increase value of verbosity.
-* ``validate_parameters`` [default to false, except for Python and R interface]
+* ``validate_parameters`` [default to false, except for Python, R and CLI interface]
   - When set to True, XGBoost will perform validation of input parameters to check whether
     a parameter is used or not. The feature is still experimental. It's expected to have
-    some false positives, especially when used with Scikit-Learn interface.
+    some false positives.
 * ``nthread`` [default to maximum number of threads available if not set]
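A small Python illustration of what parameter validation reports; the misspelled ``max_dept`` is deliberate and the data is synthetic:

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(100, 10)
    y = np.random.randint(2, size=100)
    dtrain = xgb.DMatrix(X, label=y)

    # 'max_dept' is a typo for 'max_depth'; with validate_parameters enabled
    # XGBoost reports it as an unused parameter instead of silently ignoring it.
    params = {'objective': 'binary:logistic',
              'validate_parameters': True,
              'max_dept': 3}
    xgb.train(params, dtrain, num_boost_round=2)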

View File

@@ -4,28 +4,29 @@
  * \brief The command line interface program of xgboost.
  *  This file is not included in dynamic library.
  */
-// Copyright 2014 by Contributors
 #define _CRT_SECURE_NO_WARNINGS
 #define _CRT_SECURE_NO_DEPRECATE
 #define NOMINMAX
+#include <dmlc/timer.h>
 #include <xgboost/learner.h>
 #include <xgboost/data.h>
+#include <xgboost/json.h>
 #include <xgboost/logging.h>
 #include <xgboost/parameter.h>
-#include <dmlc/timer.h>
 #include <iomanip>
 #include <ctime>
 #include <string>
 #include <cstdio>
 #include <cstring>
 #include <vector>
-#include "./common/common.h"
-#include "./common/config.h"
+#include "common/common.h"
+#include "common/config.h"
+#include "common/io.h"
+#include "common/version.h"

 namespace xgboost {
 enum CLITask {
   kTrain = 0,
   kDumpModel = 1,
@@ -74,6 +75,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
   /*! \brief all the configurations */
   std::vector<std::pair<std::string, std::string> > cfg;

+  static constexpr char const* const kNull = "NULL";
+
   // declare parameters
   DMLC_DECLARE_PARAMETER(CLIParam) {
     // NOTE: declare everything except eval_data_paths.
@@ -124,15 +127,18 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
   }
   // customized configure function of CLIParam
   inline void Configure(const std::vector<std::pair<std::string, std::string> >& _cfg) {
-    this->cfg = _cfg;
-    this->UpdateAllowUnknown(_cfg);
-    for (const auto& kv : _cfg) {
+    // Don't copy the configuration to enable parameter validation.
+    auto unknown_cfg = this->UpdateAllowUnknown(_cfg);
+    this->cfg.emplace_back("validate_parameters", "True");
+    for (const auto& kv : unknown_cfg) {
       if (!strncmp("eval[", kv.first.c_str(), 5)) {
         char evname[256];
         CHECK_EQ(sscanf(kv.first.c_str(), "eval[%[^]]", evname), 1)
             << "must specify evaluation name for display";
         eval_data_names.emplace_back(evname);
         eval_data_paths.push_back(kv.second);
+      } else {
+        this->cfg.emplace_back(kv);
       }
     }
     // constraint.
@@ -145,70 +151,95 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
   }
 };

+constexpr char const* const CLIParam::kNull;
+
 DMLC_REGISTER_PARAMETER(CLIParam);

-void CLITrain(const CLIParam& param) {
+std::string CliHelp() {
+  return "Use xgboost -h for showing help information.\n";
+}
+
+void CLIError(dmlc::Error const& e) {
+  std::cerr << "Error running xgboost:\n\n"
+            << e.what() << "\n"
+            << CliHelp()
+            << std::endl;
+}
+
+class CLI {
+  CLIParam param_;
+  std::unique_ptr<Learner> learner_;
+  enum Print {
+    kNone,
+    kVersion,
+    kHelp
+  } print_info_ {kNone};
+
+  int ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
+    learner_.reset(Learner::Create(matrices));
+    int version = rabit::LoadCheckPoint(learner_.get());
+    if (version == 0) {
+      if (param_.model_in != CLIParam::kNull) {
+        this->LoadModel(param_.model_in, learner_.get());
+        learner_->SetParams(param_.cfg);
+      } else {
+        learner_->SetParams(param_.cfg);
+      }
+    }
+    learner_->Configure();
+    return version;
+  }
+
+  void CLITrain() {
     const double tstart_data_load = dmlc::GetTime();
     if (rabit::IsDistributed()) {
       std::string pname = rabit::GetProcessorName();
       LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
     }
     // load in data.
-    std::shared_ptr<DMatrix> dtrain(
-        DMatrix::Load(
-            param.train_path,
+    std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
+        param_.train_path,
         ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
-            param.dsplit == 2));
+        param_.dsplit == 2));
     std::vector<std::shared_ptr<DMatrix>> deval;
     std::vector<std::shared_ptr<DMatrix>> cache_mats;
     std::vector<std::shared_ptr<DMatrix>> eval_datasets;
     cache_mats.push_back(dtrain);
-    for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
-      deval.emplace_back(
-          std::shared_ptr<DMatrix>(DMatrix::Load(
-              param.eval_data_paths[i],
+    for (size_t i = 0; i < param_.eval_data_names.size(); ++i) {
+      deval.emplace_back(std::shared_ptr<DMatrix>(DMatrix::Load(
+          param_.eval_data_paths[i],
           ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
-              param.dsplit == 2)));
+          param_.dsplit == 2)));
       eval_datasets.push_back(deval.back());
       cache_mats.push_back(deval.back());
     }
-    std::vector<std::string> eval_data_names = param.eval_data_names;
-    if (param.eval_train) {
+    std::vector<std::string> eval_data_names = param_.eval_data_names;
+    if (param_.eval_train) {
       eval_datasets.push_back(dtrain);
       eval_data_names.emplace_back("train");
     }
     // initialize the learner.
-    std::unique_ptr<Learner> learner(Learner::Create(cache_mats));
-    int version = rabit::LoadCheckPoint(learner.get());
-    if (version == 0) {
-      // initialize the model if needed.
-      if (param.model_in != "NULL") {
-        std::unique_ptr<dmlc::Stream> fi(
-            dmlc::Stream::Create(param.model_in.c_str(), "r"));
-        learner->LoadModel(fi.get());
-        learner->SetParams(param.cfg);
-      } else {
-        learner->SetParams(param.cfg);
-      }
-    }
-    LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec";
+    int32_t version = this->ResetLearner(cache_mats);
+    LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load
+              << " sec";
     // start training.
     const double start = dmlc::GetTime();
-    for (int i = version / 2; i < param.num_round; ++i) {
+    for (int i = version / 2; i < param_.num_round; ++i) {
       double elapsed = dmlc::GetTime() - start;
       if (version % 2 == 0) {
-        LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
-        learner->UpdateOneIter(i, dtrain);
-        if (learner->AllowLazyCheckPoint()) {
-          rabit::LazyCheckPoint(learner.get());
+        LOG(INFO) << "boosting round " << i << ", " << elapsed
+                  << " sec elapsed";
+        learner_->UpdateOneIter(i, dtrain);
+        if (learner_->AllowLazyCheckPoint()) {
+          rabit::LazyCheckPoint(learner_.get());
         } else {
-          rabit::CheckPoint(learner.get());
+          rabit::CheckPoint(learner_.get());
         }
         version += 1;
       }
       CHECK_EQ(version, rabit::VersionNumber());
-      std::string res = learner->EvalOneIter(i, eval_datasets, eval_data_names);
+      std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
       if (rabit::IsDistributed()) {
         if (rabit::GetRank() == 0) {
           LOG(TRACKER) << res;
@@ -216,74 +247,66 @@ void CLITrain(const CLIParam& param) {
       } else {
         LOG(CONSOLE) << res;
       }
-      if (param.save_period != 0 &&
-          (i + 1) % param.save_period == 0 &&
+      if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
           rabit::GetRank() == 0) {
         std::ostringstream os;
-        os << param.model_dir << '/'
-           << std::setfill('0') << std::setw(4)
+        os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
            << i + 1 << ".model";
-        std::unique_ptr<dmlc::Stream> fo(
-            dmlc::Stream::Create(os.str().c_str(), "w"));
-        learner->SaveModel(fo.get());
+        this->SaveModel(os.str(), learner_.get());
       }
-      if (learner->AllowLazyCheckPoint()) {
-        rabit::LazyCheckPoint(learner.get());
+      if (learner_->AllowLazyCheckPoint()) {
+        rabit::LazyCheckPoint(learner_.get());
       } else {
-        rabit::CheckPoint(learner.get());
+        rabit::CheckPoint(learner_.get());
       }
       version += 1;
       CHECK_EQ(version, rabit::VersionNumber());
     }
-    LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start << " sec";
+    LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start
+              << " sec";
     // always save final round
-    if ((param.save_period == 0 || param.num_round % param.save_period != 0) &&
-        param.model_out != "NONE" &&
-        rabit::GetRank() == 0) {
+    if ((param_.save_period == 0 ||
+         param_.num_round % param_.save_period != 0) &&
+        param_.model_out != CLIParam::kNull && rabit::GetRank() == 0) {
       std::ostringstream os;
-      if (param.model_out == "NULL") {
-        os << param.model_dir << '/'
-           << std::setfill('0') << std::setw(4)
-           << param.num_round << ".model";
+      if (param_.model_out == CLIParam::kNull) {
+        os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
+           << param_.num_round << ".model";
       } else {
-        os << param.model_out;
+        os << param_.model_out;
       }
-      std::unique_ptr<dmlc::Stream> fo(
-          dmlc::Stream::Create(os.str().c_str(), "w"));
-      learner->SaveModel(fo.get());
+      this->SaveModel(os.str(), learner_.get());
     }
     double elapsed = dmlc::GetTime() - start;
     LOG(INFO) << "update end, " << elapsed << " sec in all";
   }

-void CLIDumpModel(const CLIParam& param) {
+  void CLIDumpModel() {
     FeatureMap fmap;
-    if (param.name_fmap != "NULL") {
+    if (param_.name_fmap != CLIParam::kNull) {
       std::unique_ptr<dmlc::Stream> fs(
-          dmlc::Stream::Create(param.name_fmap.c_str(), "r"));
+          dmlc::Stream::Create(param_.name_fmap.c_str(), "r"));
       dmlc::istream is(fs.get());
       fmap.LoadText(is);
     }
     // load model
-    CHECK_NE(param.model_in, "NULL")
-        << "Must specify model_in for dump";
-    std::unique_ptr<Learner> learner(Learner::Create({}));
-    std::unique_ptr<dmlc::Stream> fi(
-        dmlc::Stream::Create(param.model_in.c_str(), "r"));
-    learner->SetParams(param.cfg);
-    learner->LoadModel(fi.get());
+    CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for dump";
+    this->ResetLearner({});
     // dump data
-    std::vector<std::string> dump = learner->DumpModel(
-        fmap, param.dump_stats, param.dump_format);
+    std::vector<std::string> dump =
+        learner_->DumpModel(fmap, param_.dump_stats, param_.dump_format);
     std::unique_ptr<dmlc::Stream> fo(
-        dmlc::Stream::Create(param.name_dump.c_str(), "w"));
+        dmlc::Stream::Create(param_.name_dump.c_str(), "w"));
     dmlc::ostream os(fo.get());
-    if (param.dump_format == "json") {
+    if (param_.dump_format == "json") {
       os << "[" << std::endl;
       for (size_t i = 0; i < dump.size(); ++i) {
-        if (i != 0) os << "," << std::endl;
+        if (i != 0) {
+          os << "," << std::endl;
+        }
         os << dump[i];  // Dump the previously generated JSON here
       }
       os << std::endl << "]" << std::endl;
@@ -297,48 +320,148 @@ void CLIDumpModel(const CLIParam& param) {
     os.set_stream(nullptr);
   }

-void CLIPredict(const CLIParam& param) {
-  CHECK_NE(param.test_path, "NULL")
+  void CLIPredict() {
+    CHECK_NE(param_.test_path, CLIParam::kNull)
         << "Test dataset parameter test:data must be specified.";
     // load data
-    std::shared_ptr<DMatrix> dtest(
-        DMatrix::Load(
-            param.test_path,
+    std::shared_ptr<DMatrix> dtest(DMatrix::Load(
+        param_.test_path,
         ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
-            param.dsplit == 2));
+        param_.dsplit == 2));
     // load model
-    CHECK_NE(param.model_in, "NULL")
-        << "Must specify model_in for predict";
-    std::unique_ptr<Learner> learner(Learner::Create({}));
-    std::unique_ptr<dmlc::Stream> fi(
-        dmlc::Stream::Create(param.model_in.c_str(), "r"));
-    learner->LoadModel(fi.get());
-    learner->SetParams(param.cfg);
+    CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict";
+    this->ResetLearner({});

-    LOG(INFO) << "start prediction...";
+    LOG(INFO) << "Start prediction...";
     HostDeviceVector<bst_float> preds;
-    learner->Predict(dtest, param.pred_margin, &preds, param.ntree_limit);
-    LOG(CONSOLE) << "writing prediction to " << param.name_pred;
+    learner_->Predict(dtest, param_.pred_margin, &preds, param_.ntree_limit);
+    LOG(CONSOLE) << "Writing prediction to " << param_.name_pred;
     std::unique_ptr<dmlc::Stream> fo(
-        dmlc::Stream::Create(param.name_pred.c_str(), "w"));
+        dmlc::Stream::Create(param_.name_pred.c_str(), "w"));
     dmlc::ostream os(fo.get());
     for (bst_float p : preds.ConstHostVector()) {
-      os << std::setprecision(std::numeric_limits<bst_float>::max_digits10)
-         << p << '\n';
+      os << std::setprecision(std::numeric_limits<bst_float>::max_digits10) << p
+         << '\n';
     }
     // force flush before fo destruct.
     os.set_stream(nullptr);
   }

-int CLIRunTask(int argc, char *argv[]) {
-  if (argc < 2) {
-    printf("Usage: <config>\n");
-    return 0;
-  }
-  rabit::Init(argc, argv);
+  void LoadModel(std::string const& path, Learner* learner) const {
+    if (common::FileExtension(path) == "json") {
+      auto str = common::LoadSequentialFile(path);
+      CHECK_GT(str.size(), 2);
+      CHECK_EQ(str[0], '{');
+      Json in{Json::Load({str.c_str(), str.size()})};
+      learner->LoadModel(in);
+    } else {
+      std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path.c_str(), "r"));
+      learner->LoadModel(fi.get());
+    }
+  }

-  common::ConfigParser cp(argv[1]);
+  void SaveModel(std::string const& path, Learner* learner) const {
+    learner->Configure();
+    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path.c_str(), "w"));
+    if (common::FileExtension(path) == "json") {
+      Json out{Object()};
+      learner->SaveModel(&out);
+      std::string str;
+      Json::Dump(out, &str);
+      fo->Write(str.c_str(), str.size());
+    } else {
+      learner->SaveModel(fo.get());
+    }
+  }
+
+  void PrintHelp() const {
+    std::cout << "Usage: xgboost [ -h ] [ -V ] [ config file ] [ arguments ]" << std::endl;
+    std::stringstream ss;
+    ss << R"(
+Options and arguments:
+  -h, --help
+     Print this message.
+  -V, --version
+     Print XGBoost version.
+  arguments
+     Extra parameters that are not specified in config file, see below.
+Config file specifies the configuration for both training and testing.  Each line
+containing the [attribute] = [value] configuration.
+General XGBoost parameters:
+  https://xgboost.readthedocs.io/en/latest/parameter.html
+Command line interface specfic parameters:
+)";
+
+    std::string help = param_.__DOC__();
+    auto splited = common::Split(help, '\n');
+    for (auto str : splited) {
+      ss << " " << str << '\n';
+    }
+    ss << R"( eval[NAME]: string, optional, default='NULL'
+    Path to evaluation data, with NAME as data name.
+)";
+    ss << R"(
+Example: train.conf
+  # General parameters
+  booster = gbtree
+  objective = reg:squarederror
+  eta = 1.0
+  gamma = 1.0
+  seed = 0
+  min_child_weight = 0
+  max_depth = 3
+  # Training arguments for CLI.
+  num_round = 2
+  save_period = 0
+  data = "demo/data/agaricus.txt.train?format=libsvm"
+  eval[test] = "demo/data/agaricus.txt.test?format=libsvm"
+See demo/ directory in XGBoost for more examples.
+)";
+    std::cout << ss.str() << std::endl;
+  }
+
+  void PrintVersion() const {
+    auto ver = Version::String(Version::Self());
+    std::cout << "XGBoost: " << ver << std::endl;
+  }
+
+ public:
+  CLI(int argc, char* argv[]) {
+    if (argc < 2) {
+      this->PrintHelp();
+      exit(1);
+    }
+    for (int i = 0; i < argc; ++i) {
+      std::string str {argv[i]};
+      if (str == "-h" || str == "--help") {
+        print_info_ = kHelp;
+        break;
+      } else if (str == "-V" || str == "--version") {
+        print_info_ = kVersion;
+        break;
+      }
+    }
+    if (print_info_ != kNone) {
+      return;
+    }
+
+    rabit::Init(argc, argv);
+    std::string config_path = argv[1];
+
+    common::ConfigParser cp(config_path);
     auto cfg = cp.Parse();

     for (int i = 2; i < argc; ++i) {
@@ -347,19 +470,57 @@ int CLIRunTask(int argc, char *argv[]) {
         cfg.emplace_back(std::string(name), std::string(val));
       }
     }
-  CLIParam param;
-  param.Configure(cfg);
-  switch (param.task) {
-    case kTrain: CLITrain(param); break;
-    case kDumpModel: CLIDumpModel(param); break;
-    case kPredict: CLIPredict(param); break;
-  }
-  rabit::Finalize();
-  return 0;
-}
+
+    param_.Configure(cfg);
+  }
+
+  int Run() {
+    switch (this->print_info_) {
+      case kNone:
+        break;
+      case kVersion: {
+        this->PrintVersion();
+        return 0;
+      }
+      case kHelp: {
+        this->PrintHelp();
+        return 0;
+      }
+    }
+
+    try {
+      switch (param_.task) {
+        case kTrain:
+          CLITrain();
+          break;
+        case kDumpModel:
+          CLIDumpModel();
+          break;
+        case kPredict:
+          CLIPredict();
+          break;
+      }
+    } catch (dmlc::Error const& e) {
+      xgboost::CLIError(e);
+      return 1;
+    }
+    return 0;
+  }
+
+  ~CLI() {
+    rabit::Finalize();
+  }
+};
 }  // namespace xgboost

 int main(int argc, char *argv[]) {
-  return xgboost::CLIRunTask(argc, argv);
+  try {
+    xgboost::CLI cli(argc, argv);
+    return cli.Run();
+  } catch (dmlc::Error const& e) {
+    // This captures only the initialization error.
+    xgboost::CLIError(e);
+    return 1;
+  }
+  return 0;
 }
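The end-to-end effect of the new extension-based model I/O can be sketched as follows. It is a sketch only: the `./xgboost` path and running from the repository root (so the demo data path resolves) are assumptions, and the config keys mirror the example embedded in PrintHelp():

    import json
    import os
    import subprocess
    import tempfile

    exe = './xgboost'  # assumed location of the built CLI binary

    with tempfile.TemporaryDirectory() as tmpdir:
        model_path = os.path.join(tmpdir, 'model.json')
        conf_path = os.path.join(tmpdir, 'train.conf')
        with open(conf_path, 'w') as fd:
            fd.write('booster = gbtree\n'
                     'objective = binary:logistic\n'
                     'num_round = 2\n'
                     'save_period = 0\n'
                     'data = "demo/data/agaricus.txt.train?format=libsvm"\n')

        # argv[2:] entries are parsed as extra name=value pairs on top of the
        # config file; the ".json" extension routes CLI::SaveModel to the JSON
        # serializer rather than the binary one.
        subprocess.run([exe, conf_path, 'model_out={}'.format(model_path)])

        with open(model_path, 'r') as fd:
            model = json.load(fd)
        print(model['learner']['gradient_booster']['name'])  # e.g. 'gbtree'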

View File

@@ -7,8 +7,6 @@
 #ifndef XGBOOST_COMMON_CONFIG_H_
 #define XGBOOST_COMMON_CONFIG_H_

-#include <xgboost/logging.h>
-#include <cstdio>
 #include <string>
 #include <fstream>
 #include <istream>
@@ -18,6 +16,8 @@
 #include <iterator>
 #include <utility>

+#include "xgboost/logging.h"
+
 namespace xgboost {
 namespace common {
 /*!
@@ -40,10 +40,16 @@ class ConfigParser {
   std::string LoadConfigFile(const std::string& path) {
     std::ifstream fin(path, std::ios_base::in | std::ios_base::binary);
-    CHECK(fin) << "Failed to open: " << path;
-    std::string content{std::istreambuf_iterator<char>(fin),
-                        std::istreambuf_iterator<char>()};
-    return content;
+    CHECK(fin) << "Failed to open config file: \"" << path << "\"";
+    try {
+      std::string content{std::istreambuf_iterator<char>(fin),
+                          std::istreambuf_iterator<char>()};
+      return content;
+    } catch (std::ios_base::failure const &e) {
+      LOG(FATAL) << "Failed to read config file: \"" << path << "\"\n"
+                 << e.what();
+    }
+    return "";
   }

 /*!

View File

@@ -30,7 +30,6 @@
 #ifdef XGBOOST_USE_NCCL
 #include "nccl.h"
-#include "../common/io.h"
 #endif

 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)

View File

@@ -18,7 +18,6 @@
 #include <random>

 #include "xgboost/host_device_vector.h"
-#include "io.h"

 namespace xgboost {
 namespace common {

View File

@@ -114,5 +114,3 @@ XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft")
 }  // namespace obj
 }  // namespace xgboost

View File

@@ -17,10 +17,12 @@
 #include "xgboost/span.h"
 #include "xgboost/json.h"

+#include "../common/io.h"
 #include "../common/device_helpers.cuh"
 #include "../common/hist_util.h"
 #include "../common/timer.h"
 #include "../data/ellpack_page.cuh"
 #include "param.h"
 #include "updater_gpu_common.cuh"
 #include "constraints.cuh"

View File

@@ -5,6 +5,7 @@ import platform
 import xgboost
 import subprocess
 import numpy
+import json


 class TestCLI(unittest.TestCase):
@@ -27,20 +28,23 @@ data = {data_path}
 eval[test] = {data_path}
 '''
-    def test_cli_model(self):
     curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
     project_root = os.path.normpath(
         os.path.join(curdir, os.path.pardir, os.path.pardir))
-    data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
-        root=project_root)

+    def get_exe(self):
         if platform.system() == 'Windows':
             exe = 'xgboost.exe'
         else:
             exe = 'xgboost'
-        exe = os.path.join(project_root, exe)
+        exe = os.path.join(self.project_root, exe)
         assert os.path.exists(exe)
+        return exe
+
+    def test_cli_model(self):
+        data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
+            root=self.project_root)
+        exe = self.get_exe()

         seed = 1994
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -102,3 +106,48 @@ eval[test] = {data_path}
                 py_model_bin = fd.read()
             assert hash(cli_model_bin) == hash(py_model_bin)
+
+    def test_cli_help(self):
+        exe = self.get_exe()
+        completed = subprocess.run([exe], stdout=subprocess.PIPE)
+        error_msg = completed.stdout.decode('utf-8')
+        ret = completed.returncode
+        assert ret == 1
+        assert error_msg.find('Usage') != -1
+        assert error_msg.find('eval[NAME]') != -1
+
+        completed = subprocess.run([exe, '-V'], stdout=subprocess.PIPE)
+        msg = completed.stdout.decode('utf-8')
+        assert msg.find('XGBoost') != -1
+        v = xgboost.__version__
+        if v.find('SNAPSHOT') != -1:
+            assert msg.split(':')[1].strip() == v.split('-')[0]
+        else:
+            assert msg.split(':')[1].strip() == v
+
+    def test_cli_model_json(self):
+        exe = self.get_exe()
+        data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
+            root=self.project_root)
+        seed = 1994
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model_out_cli = os.path.join(
+                tmpdir, 'test_load_cli_model-cli.json')
+            config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')
+            train_conf = self.template.format(data_path=data_path,
+                                              seed=seed,
+                                              task='train',
+                                              model_in='NULL',
+                                              model_out=model_out_cli,
+                                              test_path='NULL',
+                                              name_pred='NULL')
+            with open(config_path, 'w') as fd:
+                fd.write(train_conf)
+
+            subprocess.run([exe, config_path])
+            with open(model_out_cli, 'r') as fd:
+                model = json.load(fd)
+            assert model['learner']['gradient_booster']['name'] == 'gbtree'