Refactor the CLI. (#5574)

* Enable parameter validation.
* Enable JSON.
* Catch `dmlc::Error`.
* Show help message.
Jiaming Yuan 2020-04-26 10:56:33 +08:00 committed by GitHub
parent 7d93932423
commit c90457f489
8 changed files with 432 additions and 218 deletions
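A minimal sketch of how the refactored binary behaves after this change, assuming a built xgboost executable on PATH and a train.conf similar to the example embedded in the new help text (both are assumptions, not part of the diff):

    import subprocess

    # New in this commit: -h/--help and -V/--version are recognised.
    subprocess.run(['xgboost', '--help'])
    subprocess.run(['xgboost', '--version'])

    # Extra name=value arguments override the config file; a model_out path
    # ending in ".json" makes the CLI save the model as JSON.
    subprocess.run(['xgboost', 'train.conf', 'model_out=model.json'])

Errors raised as dmlc::Error during a run are now caught and reported with a pointer to xgboost -h instead of escaping as an unhandled exception.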


@@ -30,11 +30,11 @@ General Parameters
   is displayed as warning message. If there's unexpected behaviour, please try to
   increase value of verbosity.

-* ``validate_parameters`` [default to false, except for Python and R interface]
+* ``validate_parameters`` [default to false, except for Python, R and CLI interface]

   - When set to True, XGBoost will perform validation of input parameters to check whether
     a parameter is used or not. The feature is still experimental. It's expected to have
-    some false positives, especially when used with Scikit-Learn interface.
+    some false positives.

 * ``nthread`` [default to maximum number of threads available if not set]
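As an illustration of what this validation does, a minimal Python sketch with synthetic data (the data and parameter values are made up for the example):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(100, 4)
    y = np.random.randint(0, 2, size=100)
    dtrain = xgb.DMatrix(X, label=y)

    # 'max_dept' is a typo for 'max_depth'; with validate_parameters enabled
    # XGBoost warns that the parameter was not used by any component.
    xgb.train({'validate_parameters': True,
               'max_dept': 3,
               'objective': 'binary:logistic'},
              dtrain, num_boost_round=2)

The CLI change below passes validate_parameters=True on its own, so the same kind of warning now appears for typos in a config file.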


@@ -4,28 +4,29 @@
  * \brief The command line interface program of xgboost.
  * This file is not included in dynamic library.
  */
-// Copyright 2014 by Contributors
 #define _CRT_SECURE_NO_WARNINGS
 #define _CRT_SECURE_NO_DEPRECATE
 #define NOMINMAX
+#include <dmlc/timer.h>
 #include <xgboost/learner.h>
 #include <xgboost/data.h>
+#include <xgboost/json.h>
 #include <xgboost/logging.h>
 #include <xgboost/parameter.h>
-#include <dmlc/timer.h>
 #include <iomanip>
 #include <ctime>
 #include <string>
 #include <cstdio>
 #include <cstring>
 #include <vector>
-#include "./common/common.h"
-#include "./common/config.h"
+#include "common/common.h"
+#include "common/config.h"
+#include "common/io.h"
+#include "common/version.h"

 namespace xgboost {
 enum CLITask {
   kTrain = 0,
   kDumpModel = 1,
@@ -74,6 +75,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
   /*! \brief all the configurations */
   std::vector<std::pair<std::string, std::string> > cfg;

+  static constexpr char const* const kNull = "NULL";
+
   // declare parameters
   DMLC_DECLARE_PARAMETER(CLIParam) {
     // NOTE: declare everything except eval_data_paths.
@@ -124,15 +127,18 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
   }
   // customized configure function of CLIParam
   inline void Configure(const std::vector<std::pair<std::string, std::string> >& _cfg) {
-    this->cfg = _cfg;
-    this->UpdateAllowUnknown(_cfg);
-    for (const auto& kv : _cfg) {
+    // Don't copy the configuration to enable parameter validation.
+    auto unknown_cfg = this->UpdateAllowUnknown(_cfg);
+    this->cfg.emplace_back("validate_parameters", "True");
+    for (const auto& kv : unknown_cfg) {
       if (!strncmp("eval[", kv.first.c_str(), 5)) {
         char evname[256];
         CHECK_EQ(sscanf(kv.first.c_str(), "eval[%[^]]", evname), 1)
             << "must specify evaluation name for display";
         eval_data_names.emplace_back(evname);
         eval_data_paths.push_back(kv.second);
+      } else {
+        this->cfg.emplace_back(kv);
       }
     }
     // constraint.
@@ -145,221 +151,376 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
   }
 };

+constexpr char const* const CLIParam::kNull;
+
 DMLC_REGISTER_PARAMETER(CLIParam);

-void CLITrain(const CLIParam& param) {
-  const double tstart_data_load = dmlc::GetTime();
-  if (rabit::IsDistributed()) {
-    std::string pname = rabit::GetProcessorName();
-    LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
-  }
-  // load in data.
-  std::shared_ptr<DMatrix> dtrain(
-      DMatrix::Load(
-          param.train_path,
-          ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
-          param.dsplit == 2));
-  std::vector<std::shared_ptr<DMatrix> > deval;
-  std::vector<std::shared_ptr<DMatrix> > cache_mats;
-  std::vector<std::shared_ptr<DMatrix>> eval_datasets;
-  cache_mats.push_back(dtrain);
-  for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
-    deval.emplace_back(
-        std::shared_ptr<DMatrix>(DMatrix::Load(
-            param.eval_data_paths[i],
-            ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
-            param.dsplit == 2)));
-    eval_datasets.push_back(deval.back());
-    cache_mats.push_back(deval.back());
-  }
-  std::vector<std::string> eval_data_names = param.eval_data_names;
-  if (param.eval_train) {
-    eval_datasets.push_back(dtrain);
-    eval_data_names.emplace_back("train");
-  }
-  // initialize the learner.
-  std::unique_ptr<Learner> learner(Learner::Create(cache_mats));
-  int version = rabit::LoadCheckPoint(learner.get());
-  if (version == 0) {
-    // initialize the model if needed.
-    if (param.model_in != "NULL") {
-      std::unique_ptr<dmlc::Stream> fi(
-          dmlc::Stream::Create(param.model_in.c_str(), "r"));
-      learner->LoadModel(fi.get());
-      learner->SetParams(param.cfg);
-    } else {
-      learner->SetParams(param.cfg);
-    }
-  }
-  LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec";
-  // start training.
-  const double start = dmlc::GetTime();
-  for (int i = version / 2; i < param.num_round; ++i) {
-    double elapsed = dmlc::GetTime() - start;
-    if (version % 2 == 0) {
-      LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
-      learner->UpdateOneIter(i, dtrain);
-      if (learner->AllowLazyCheckPoint()) {
-        rabit::LazyCheckPoint(learner.get());
-      } else {
-        rabit::CheckPoint(learner.get());
-      }
-      version += 1;
-    }
-    CHECK_EQ(version, rabit::VersionNumber());
-    std::string res = learner->EvalOneIter(i, eval_datasets, eval_data_names);
-    if (rabit::IsDistributed()) {
-      if (rabit::GetRank() == 0) {
-        LOG(TRACKER) << res;
-      }
-    } else {
-      LOG(CONSOLE) << res;
-    }
-    if (param.save_period != 0 &&
-        (i + 1) % param.save_period == 0 &&
-        rabit::GetRank() == 0) {
-      std::ostringstream os;
-      os << param.model_dir << '/'
-         << std::setfill('0') << std::setw(4)
-         << i + 1 << ".model";
-      std::unique_ptr<dmlc::Stream> fo(
-          dmlc::Stream::Create(os.str().c_str(), "w"));
-      learner->SaveModel(fo.get());
-    }
-    if (learner->AllowLazyCheckPoint()) {
-      rabit::LazyCheckPoint(learner.get());
-    } else {
-      rabit::CheckPoint(learner.get());
-    }
-    version += 1;
-    CHECK_EQ(version, rabit::VersionNumber());
-  }
-  LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start << " sec";
-  // always save final round
-  if ((param.save_period == 0 || param.num_round % param.save_period != 0) &&
-      param.model_out != "NONE" &&
-      rabit::GetRank() == 0) {
-    std::ostringstream os;
-    if (param.model_out == "NULL") {
-      os << param.model_dir << '/'
-         << std::setfill('0') << std::setw(4)
-         << param.num_round << ".model";
-    } else {
-      os << param.model_out;
-    }
-    std::unique_ptr<dmlc::Stream> fo(
-        dmlc::Stream::Create(os.str().c_str(), "w"));
-    learner->SaveModel(fo.get());
-  }
-  double elapsed = dmlc::GetTime() - start;
-  LOG(INFO) << "update end, " << elapsed << " sec in all";
-}
-
-void CLIDumpModel(const CLIParam& param) {
-  FeatureMap fmap;
-  if (param.name_fmap != "NULL") {
-    std::unique_ptr<dmlc::Stream> fs(
-        dmlc::Stream::Create(param.name_fmap.c_str(), "r"));
-    dmlc::istream is(fs.get());
-    fmap.LoadText(is);
-  }
-  // load model
-  CHECK_NE(param.model_in, "NULL")
-      << "Must specify model_in for dump";
-  std::unique_ptr<Learner> learner(Learner::Create({}));
-  std::unique_ptr<dmlc::Stream> fi(
-      dmlc::Stream::Create(param.model_in.c_str(), "r"));
-  learner->SetParams(param.cfg);
-  learner->LoadModel(fi.get());
-  // dump data
-  std::vector<std::string> dump = learner->DumpModel(
-      fmap, param.dump_stats, param.dump_format);
-  std::unique_ptr<dmlc::Stream> fo(
-      dmlc::Stream::Create(param.name_dump.c_str(), "w"));
-  dmlc::ostream os(fo.get());
-  if (param.dump_format == "json") {
-    os << "[" << std::endl;
-    for (size_t i = 0; i < dump.size(); ++i) {
-      if (i != 0) os << "," << std::endl;
-      os << dump[i];  // Dump the previously generated JSON here
-    }
-    os << std::endl << "]" << std::endl;
-  } else {
-    for (size_t i = 0; i < dump.size(); ++i) {
-      os << "booster[" << i << "]:\n";
-      os << dump[i];
-    }
-  }
-  // force flush before fo destruct.
-  os.set_stream(nullptr);
-}
-
-void CLIPredict(const CLIParam& param) {
-  CHECK_NE(param.test_path, "NULL")
-      << "Test dataset parameter test:data must be specified.";
-  // load data
-  std::shared_ptr<DMatrix> dtest(
-      DMatrix::Load(
-          param.test_path,
-          ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
-          param.dsplit == 2));
-  // load model
-  CHECK_NE(param.model_in, "NULL")
-      << "Must specify model_in for predict";
-  std::unique_ptr<Learner> learner(Learner::Create({}));
-  std::unique_ptr<dmlc::Stream> fi(
-      dmlc::Stream::Create(param.model_in.c_str(), "r"));
-  learner->LoadModel(fi.get());
-  learner->SetParams(param.cfg);
-  LOG(INFO) << "start prediction...";
-  HostDeviceVector<bst_float> preds;
-  learner->Predict(dtest, param.pred_margin, &preds, param.ntree_limit);
-  LOG(CONSOLE) << "writing prediction to " << param.name_pred;
-  std::unique_ptr<dmlc::Stream> fo(
-      dmlc::Stream::Create(param.name_pred.c_str(), "w"));
-  dmlc::ostream os(fo.get());
-  for (bst_float p : preds.ConstHostVector()) {
-    os << std::setprecision(std::numeric_limits<bst_float>::max_digits10)
-       << p << '\n';
-  }
-  // force flush before fo destruct.
-  os.set_stream(nullptr);
-}
-
-int CLIRunTask(int argc, char *argv[]) {
-  if (argc < 2) {
-    printf("Usage: <config>\n");
-    return 0;
-  }
-  rabit::Init(argc, argv);
-  common::ConfigParser cp(argv[1]);
-  auto cfg = cp.Parse();
-  for (int i = 2; i < argc; ++i) {
-    char name[256], val[256];
-    if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
-      cfg.emplace_back(std::string(name), std::string(val));
-    }
-  }
-  CLIParam param;
-  param.Configure(cfg);
-  switch (param.task) {
-    case kTrain: CLITrain(param); break;
-    case kDumpModel: CLIDumpModel(param); break;
-    case kPredict: CLIPredict(param); break;
-  }
-  rabit::Finalize();
-  return 0;
-}
+std::string CliHelp() {
+  return "Use xgboost -h for showing help information.\n";
+}
+
+void CLIError(dmlc::Error const& e) {
+  std::cerr << "Error running xgboost:\n\n"
+            << e.what() << "\n"
+            << CliHelp()
+            << std::endl;
+}
+
+class CLI {
+  CLIParam param_;
+  std::unique_ptr<Learner> learner_;
+  enum Print {
+    kNone,
+    kVersion,
+    kHelp
+  } print_info_ {kNone};
+
+  int ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
+    learner_.reset(Learner::Create(matrices));
+    int version = rabit::LoadCheckPoint(learner_.get());
+    if (version == 0) {
+      if (param_.model_in != CLIParam::kNull) {
+        this->LoadModel(param_.model_in, learner_.get());
+        learner_->SetParams(param_.cfg);
+      } else {
+        learner_->SetParams(param_.cfg);
+      }
+    }
+    learner_->Configure();
+    return version;
+  }
+
+  void CLITrain() {
+    const double tstart_data_load = dmlc::GetTime();
+    if (rabit::IsDistributed()) {
+      std::string pname = rabit::GetProcessorName();
+      LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
+    }
+    // load in data.
+    std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
+        param_.train_path,
+        ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
+        param_.dsplit == 2));
+    std::vector<std::shared_ptr<DMatrix>> deval;
+    std::vector<std::shared_ptr<DMatrix>> cache_mats;
+    std::vector<std::shared_ptr<DMatrix>> eval_datasets;
+    cache_mats.push_back(dtrain);
+    for (size_t i = 0; i < param_.eval_data_names.size(); ++i) {
+      deval.emplace_back(std::shared_ptr<DMatrix>(DMatrix::Load(
+          param_.eval_data_paths[i],
+          ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
+          param_.dsplit == 2)));
+      eval_datasets.push_back(deval.back());
+      cache_mats.push_back(deval.back());
+    }
+    std::vector<std::string> eval_data_names = param_.eval_data_names;
+    if (param_.eval_train) {
+      eval_datasets.push_back(dtrain);
+      eval_data_names.emplace_back("train");
+    }
+    // initialize the learner.
+    int32_t version = this->ResetLearner(cache_mats);
+    LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load
+              << " sec";
+
+    // start training.
+    const double start = dmlc::GetTime();
+    for (int i = version / 2; i < param_.num_round; ++i) {
+      double elapsed = dmlc::GetTime() - start;
+      if (version % 2 == 0) {
+        LOG(INFO) << "boosting round " << i << ", " << elapsed
+                  << " sec elapsed";
+        learner_->UpdateOneIter(i, dtrain);
+        if (learner_->AllowLazyCheckPoint()) {
+          rabit::LazyCheckPoint(learner_.get());
+        } else {
+          rabit::CheckPoint(learner_.get());
+        }
+        version += 1;
+      }
+      CHECK_EQ(version, rabit::VersionNumber());
+      std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
+      if (rabit::IsDistributed()) {
+        if (rabit::GetRank() == 0) {
+          LOG(TRACKER) << res;
+        }
+      } else {
+        LOG(CONSOLE) << res;
+      }
+      if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
+          rabit::GetRank() == 0) {
+        std::ostringstream os;
+        os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
+           << i + 1 << ".model";
+        this->SaveModel(os.str(), learner_.get());
+      }
+      if (learner_->AllowLazyCheckPoint()) {
+        rabit::LazyCheckPoint(learner_.get());
+      } else {
+        rabit::CheckPoint(learner_.get());
+      }
+      version += 1;
+      CHECK_EQ(version, rabit::VersionNumber());
+    }
+    LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start
+              << " sec";
+    // always save final round
+    if ((param_.save_period == 0 ||
+         param_.num_round % param_.save_period != 0) &&
+        param_.model_out != CLIParam::kNull && rabit::GetRank() == 0) {
+      std::ostringstream os;
+      if (param_.model_out == CLIParam::kNull) {
+        os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
+           << param_.num_round << ".model";
+      } else {
+        os << param_.model_out;
+      }
+      this->SaveModel(os.str(), learner_.get());
+    }
+    double elapsed = dmlc::GetTime() - start;
+    LOG(INFO) << "update end, " << elapsed << " sec in all";
+  }
+
+  void CLIDumpModel() {
+    FeatureMap fmap;
+    if (param_.name_fmap != CLIParam::kNull) {
+      std::unique_ptr<dmlc::Stream> fs(
+          dmlc::Stream::Create(param_.name_fmap.c_str(), "r"));
+      dmlc::istream is(fs.get());
+      fmap.LoadText(is);
+    }
+    // load model
+    CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for dump";
+    this->ResetLearner({});
+    // dump data
+    std::vector<std::string> dump =
+        learner_->DumpModel(fmap, param_.dump_stats, param_.dump_format);
+    std::unique_ptr<dmlc::Stream> fo(
+        dmlc::Stream::Create(param_.name_dump.c_str(), "w"));
+    dmlc::ostream os(fo.get());
+    if (param_.dump_format == "json") {
+      os << "[" << std::endl;
+      for (size_t i = 0; i < dump.size(); ++i) {
+        if (i != 0) {
+          os << "," << std::endl;
+        }
+        os << dump[i];  // Dump the previously generated JSON here
+      }
+      os << std::endl << "]" << std::endl;
+    } else {
+      for (size_t i = 0; i < dump.size(); ++i) {
+        os << "booster[" << i << "]:\n";
+        os << dump[i];
+      }
+    }
+    // force flush before fo destruct.
+    os.set_stream(nullptr);
+  }
+
+  void CLIPredict() {
+    CHECK_NE(param_.test_path, CLIParam::kNull)
+        << "Test dataset parameter test:data must be specified.";
+    // load data
+    std::shared_ptr<DMatrix> dtest(DMatrix::Load(
+        param_.test_path,
+        ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
+        param_.dsplit == 2));
+    // load model
+    CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict";
+    this->ResetLearner({});
+
+    LOG(INFO) << "Start prediction...";
+    HostDeviceVector<bst_float> preds;
+    learner_->Predict(dtest, param_.pred_margin, &preds, param_.ntree_limit);
+    LOG(CONSOLE) << "Writing prediction to " << param_.name_pred;
+
+    std::unique_ptr<dmlc::Stream> fo(
+        dmlc::Stream::Create(param_.name_pred.c_str(), "w"));
+    dmlc::ostream os(fo.get());
+    for (bst_float p : preds.ConstHostVector()) {
+      os << std::setprecision(std::numeric_limits<bst_float>::max_digits10) << p
+         << '\n';
+    }
+    // force flush before fo destruct.
+    os.set_stream(nullptr);
+  }
+
+  void LoadModel(std::string const& path, Learner* learner) const {
+    if (common::FileExtension(path) == "json") {
+      auto str = common::LoadSequentialFile(path);
+      CHECK_GT(str.size(), 2);
+      CHECK_EQ(str[0], '{');
+      Json in{Json::Load({str.c_str(), str.size()})};
+      learner->LoadModel(in);
+    } else {
+      std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path.c_str(), "r"));
+      learner->LoadModel(fi.get());
+    }
+  }
+
+  void SaveModel(std::string const& path, Learner* learner) const {
+    learner->Configure();
+    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path.c_str(), "w"));
+    if (common::FileExtension(path) == "json") {
+      Json out{Object()};
+      learner->SaveModel(&out);
+      std::string str;
+      Json::Dump(out, &str);
+      fo->Write(str.c_str(), str.size());
+    } else {
+      learner->SaveModel(fo.get());
+    }
+  }
+
+  void PrintHelp() const {
+    std::cout << "Usage: xgboost [ -h ] [ -V ] [ config file ] [ arguments ]" << std::endl;
+    std::stringstream ss;
+    ss << R"(
+  Options and arguments:
+    -h, --help
+       Print this message.
+    -V, --version
+       Print XGBoost version.
+    arguments
+       Extra parameters that are not specified in config file, see below.
+
+  Config file specifies the configuration for both training and testing.  Each line
+  containing the [attribute] = [value] configuration.
+
+  General XGBoost parameters:
+    https://xgboost.readthedocs.io/en/latest/parameter.html
+
+  Command line interface specfic parameters:
+)";
+    std::string help = param_.__DOC__();
+    auto splited = common::Split(help, '\n');
+    for (auto str : splited) {
+      ss << "    " << str << '\n';
+    }
+    ss << R"(    eval[NAME]: string, optional, default='NULL'
+        Path to evaluation data, with NAME as data name.
+)";
+    ss << R"(
+  Example:  train.conf
+    # General parameters
+    booster = gbtree
+    objective = reg:squarederror
+    eta = 1.0
+    gamma = 1.0
+    seed = 0
+    min_child_weight = 0
+    max_depth = 3
+    # Training arguments for CLI.
+    num_round = 2
+    save_period = 0
+    data = "demo/data/agaricus.txt.train?format=libsvm"
+    eval[test] = "demo/data/agaricus.txt.test?format=libsvm"
+
+  See demo/ directory in XGBoost for more examples.
+)";
+    std::cout << ss.str() << std::endl;
+  }
+
+  void PrintVersion() const {
+    auto ver = Version::String(Version::Self());
+    std::cout << "XGBoost: " << ver << std::endl;
+  }
+
+ public:
+  CLI(int argc, char* argv[]) {
+    if (argc < 2) {
+      this->PrintHelp();
+      exit(1);
+    }
+    for (int i = 0; i < argc; ++i) {
+      std::string str {argv[i]};
+      if (str == "-h" || str == "--help") {
+        print_info_ = kHelp;
+        break;
+      } else if (str == "-V" || str == "--version") {
+        print_info_ = kVersion;
+        break;
+      }
+    }
+    if (print_info_ != kNone) {
+      return;
+    }
+
+    rabit::Init(argc, argv);
+    std::string config_path = argv[1];
+
+    common::ConfigParser cp(config_path);
+    auto cfg = cp.Parse();
+
+    for (int i = 2; i < argc; ++i) {
+      char name[256], val[256];
+      if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
+        cfg.emplace_back(std::string(name), std::string(val));
+      }
+    }
+
+    param_.Configure(cfg);
+  }
+
+  int Run() {
+    switch (this->print_info_) {
+      case kNone:
+        break;
+      case kVersion: {
+        this->PrintVersion();
+        return 0;
+      }
+      case kHelp: {
+        this->PrintHelp();
+        return 0;
+      }
+    }
+    try {
+      switch (param_.task) {
+        case kTrain:
+          CLITrain();
+          break;
+        case kDumpModel:
+          CLIDumpModel();
+          break;
+        case kPredict:
+          CLIPredict();
+          break;
+      }
+    } catch (dmlc::Error const& e) {
+      xgboost::CLIError(e);
+      return 1;
+    }
+    return 0;
+  }
+
+  ~CLI() {
+    rabit::Finalize();
+  }
+};
 }  // namespace xgboost

 int main(int argc, char *argv[]) {
-  return xgboost::CLIRunTask(argc, argv);
+  try {
+    xgboost::CLI cli(argc, argv);
+    return cli.Run();
+  } catch (dmlc::Error const& e) {
+    // This captures only the initialization error.
+    xgboost::CLIError(e);
+    return 1;
+  }
+  return 0;
 }
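The LoadModel/SaveModel helpers above pick the serialization format from the file extension. A small sketch of consuming a JSON model written by the CLI, assuming a previous run saved it to model.json (a hypothetical path):

    import json
    import xgboost as xgb

    # The CLI wrote JSON because the output path ends in ".json".
    with open('model.json', 'r') as fd:
        model = json.load(fd)
    print(model['learner']['gradient_booster']['name'])

    # The same file also loads directly as a Booster for prediction in Python.
    booster = xgb.Booster(model_file='model.json')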


@@ -7,8 +7,6 @@
 #ifndef XGBOOST_COMMON_CONFIG_H_
 #define XGBOOST_COMMON_CONFIG_H_

-#include <xgboost/logging.h>
-#include <cstdio>
 #include <string>
 #include <fstream>
 #include <istream>
@@ -18,6 +16,8 @@
 #include <iterator>
 #include <utility>

+#include "xgboost/logging.h"
+
 namespace xgboost {
 namespace common {
 /*!
@@ -40,10 +40,16 @@ class ConfigParser {
   std::string LoadConfigFile(const std::string& path) {
     std::ifstream fin(path, std::ios_base::in | std::ios_base::binary);
-    CHECK(fin) << "Failed to open: " << path;
-    std::string content{std::istreambuf_iterator<char>(fin),
-                        std::istreambuf_iterator<char>()};
-    return content;
+    CHECK(fin) << "Failed to open config file: \"" << path << "\"";
+    try {
+      std::string content{std::istreambuf_iterator<char>(fin),
+                          std::istreambuf_iterator<char>()};
+      return content;
+    } catch (std::ios_base::failure const &e) {
+      LOG(FATAL) << "Failed to read config file: \"" << path << "\"\n"
+                 << e.what();
+    }
+    return "";
   }

   /*!


@@ -30,7 +30,6 @@
 #ifdef XGBOOST_USE_NCCL
 #include "nccl.h"
-#include "../common/io.h"
 #endif

 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)


@@ -18,7 +18,6 @@
 #include <random>

 #include "xgboost/host_device_vector.h"
-#include "io.h"

 namespace xgboost {
 namespace common {


@@ -114,5 +114,3 @@ XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft")
 }  // namespace obj
 }  // namespace xgboost


@@ -17,10 +17,12 @@
 #include "xgboost/span.h"
 #include "xgboost/json.h"

+#include "../common/io.h"
 #include "../common/device_helpers.cuh"
 #include "../common/hist_util.h"
 #include "../common/timer.h"
 #include "../data/ellpack_page.cuh"
 #include "param.h"
 #include "updater_gpu_common.cuh"
 #include "constraints.cuh"


@@ -5,6 +5,7 @@ import platform
 import xgboost
 import subprocess
 import numpy
+import json


 class TestCLI(unittest.TestCase):
@@ -27,20 +28,23 @@ data = {data_path}
 eval[test] = {data_path}
 '''

-    def test_cli_model(self):
-        curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
-        project_root = os.path.normpath(
-            os.path.join(curdir, os.path.pardir, os.path.pardir))
-        data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
-            root=project_root)
+    curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
+    project_root = os.path.normpath(
+        os.path.join(curdir, os.path.pardir, os.path.pardir))

+    def get_exe(self):
         if platform.system() == 'Windows':
             exe = 'xgboost.exe'
         else:
             exe = 'xgboost'
-        exe = os.path.join(project_root, exe)
+        exe = os.path.join(self.project_root, exe)
         assert os.path.exists(exe)
+        return exe
+
+    def test_cli_model(self):
+        data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
+            root=self.project_root)
+        exe = self.get_exe()

         seed = 1994
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -102,3 +106,48 @@ eval[test] = {data_path}
                 py_model_bin = fd.read()

             assert hash(cli_model_bin) == hash(py_model_bin)
+
+    def test_cli_help(self):
+        exe = self.get_exe()
+        completed = subprocess.run([exe], stdout=subprocess.PIPE)
+        error_msg = completed.stdout.decode('utf-8')
+        ret = completed.returncode
+        assert ret == 1
+        assert error_msg.find('Usage') != -1
+        assert error_msg.find('eval[NAME]') != -1
+
+        completed = subprocess.run([exe, '-V'], stdout=subprocess.PIPE)
+        msg = completed.stdout.decode('utf-8')
+        assert msg.find('XGBoost') != -1
+        v = xgboost.__version__
+        if v.find('SNAPSHOT') != -1:
+            assert msg.split(':')[1].strip() == v.split('-')[0]
+        else:
+            assert msg.split(':')[1].strip() == v
+
+    def test_cli_model_json(self):
+        exe = self.get_exe()
+        data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
+            root=self.project_root)
+        seed = 1994
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model_out_cli = os.path.join(
+                tmpdir, 'test_load_cli_model-cli.json')
+            config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')
+            train_conf = self.template.format(data_path=data_path,
+                                              seed=seed,
+                                              task='train',
+                                              model_in='NULL',
+                                              model_out=model_out_cli,
+                                              test_path='NULL',
+                                              name_pred='NULL')
+            with open(config_path, 'w') as fd:
+                fd.write(train_conf)
+            subprocess.run([exe, config_path])
+
+            with open(model_out_cli, 'r') as fd:
+                model = json.load(fd)
+            assert model['learner']['gradient_booster']['name'] == 'gbtree'
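One way to run only the new tests locally, assuming the file lives at tests/python/test_cli.py relative to the repository root (the path is an assumption) and the xgboost binary has been built there:

    import subprocess

    # Select just the CLI tests added in this commit.
    subprocess.run(['pytest', '-v',
                    'tests/python/test_cli.py::TestCLI::test_cli_help',
                    'tests/python/test_cli.py::TestCLI::test_cli_model_json'])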