Refactor the CLI. (#5574)
* Enable parameter validation. * Enable JSON. * Catch `dmlc::Error`. * Show help message.
This commit is contained in:
parent
7d93932423
commit
c90457f489
@ -30,11 +30,11 @@ General Parameters
|
||||
is displayed as warning message. If there's unexpected behaviour, please try to
|
||||
increase value of verbosity.
|
||||
|
||||
* ``validate_parameters`` [default to false, except for Python and R interface]
|
||||
* ``validate_parameters`` [default to false, except for Python, R and CLI interface]
|
||||
|
||||
- When set to True, XGBoost will perform validation of input parameters to check whether
|
||||
a parameter is used or not. The feature is still experimental. It's expected to have
|
||||
some false positives, especially when used with Scikit-Learn interface.
|
||||
some false positives.
|
||||
|
||||
* ``nthread`` [default to maximum number of threads available if not set]
|
||||
|
||||
|
||||
559
src/cli_main.cc
559
src/cli_main.cc
@ -4,28 +4,29 @@
|
||||
* \brief The command line interface program of xgboost.
|
||||
* This file is not included in dynamic library.
|
||||
*/
|
||||
// Copyright 2014 by Contributors
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <dmlc/timer.h>
|
||||
|
||||
#include <xgboost/learner.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/parameter.h>
|
||||
|
||||
#include <dmlc/timer.h>
|
||||
#include <iomanip>
|
||||
#include <ctime>
|
||||
#include <string>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include "./common/common.h"
|
||||
#include "./common/config.h"
|
||||
#include "common/common.h"
|
||||
#include "common/config.h"
|
||||
#include "common/io.h"
|
||||
#include "common/version.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
enum CLITask {
|
||||
kTrain = 0,
|
||||
kDumpModel = 1,
|
||||
@ -74,6 +75,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
|
||||
/*! \brief all the configurations */
|
||||
std::vector<std::pair<std::string, std::string> > cfg;
|
||||
|
||||
static constexpr char const* const kNull = "NULL";
|
||||
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(CLIParam) {
|
||||
// NOTE: declare everything except eval_data_paths.
|
||||
@ -124,15 +127,18 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
|
||||
}
|
||||
// customized configure function of CLIParam
|
||||
inline void Configure(const std::vector<std::pair<std::string, std::string> >& _cfg) {
|
||||
this->cfg = _cfg;
|
||||
this->UpdateAllowUnknown(_cfg);
|
||||
for (const auto& kv : _cfg) {
|
||||
// Don't copy the configuration to enable parameter validation.
|
||||
auto unknown_cfg = this->UpdateAllowUnknown(_cfg);
|
||||
this->cfg.emplace_back("validate_parameters", "True");
|
||||
for (const auto& kv : unknown_cfg) {
|
||||
if (!strncmp("eval[", kv.first.c_str(), 5)) {
|
||||
char evname[256];
|
||||
CHECK_EQ(sscanf(kv.first.c_str(), "eval[%[^]]", evname), 1)
|
||||
<< "must specify evaluation name for display";
|
||||
eval_data_names.emplace_back(evname);
|
||||
eval_data_paths.push_back(kv.second);
|
||||
} else {
|
||||
this->cfg.emplace_back(kv);
|
||||
}
|
||||
}
|
||||
// constraint.
|
||||
@ -145,221 +151,376 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
|
||||
}
|
||||
};
|
||||
|
||||
constexpr char const* const CLIParam::kNull;
|
||||
|
||||
DMLC_REGISTER_PARAMETER(CLIParam);
|
||||
|
||||
void CLITrain(const CLIParam& param) {
|
||||
const double tstart_data_load = dmlc::GetTime();
|
||||
if (rabit::IsDistributed()) {
|
||||
std::string pname = rabit::GetProcessorName();
|
||||
LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
|
||||
}
|
||||
// load in data.
|
||||
std::shared_ptr<DMatrix> dtrain(
|
||||
DMatrix::Load(
|
||||
param.train_path,
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param.dsplit == 2));
|
||||
std::vector<std::shared_ptr<DMatrix> > deval;
|
||||
std::vector<std::shared_ptr<DMatrix> > cache_mats;
|
||||
std::vector<std::shared_ptr<DMatrix>> eval_datasets;
|
||||
cache_mats.push_back(dtrain);
|
||||
for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
|
||||
deval.emplace_back(
|
||||
std::shared_ptr<DMatrix>(DMatrix::Load(
|
||||
param.eval_data_paths[i],
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param.dsplit == 2)));
|
||||
eval_datasets.push_back(deval.back());
|
||||
cache_mats.push_back(deval.back());
|
||||
}
|
||||
std::vector<std::string> eval_data_names = param.eval_data_names;
|
||||
if (param.eval_train) {
|
||||
eval_datasets.push_back(dtrain);
|
||||
eval_data_names.emplace_back("train");
|
||||
}
|
||||
// initialize the learner.
|
||||
std::unique_ptr<Learner> learner(Learner::Create(cache_mats));
|
||||
int version = rabit::LoadCheckPoint(learner.get());
|
||||
if (version == 0) {
|
||||
// initialize the model if needed.
|
||||
if (param.model_in != "NULL") {
|
||||
std::unique_ptr<dmlc::Stream> fi(
|
||||
dmlc::Stream::Create(param.model_in.c_str(), "r"));
|
||||
learner->LoadModel(fi.get());
|
||||
learner->SetParams(param.cfg);
|
||||
} else {
|
||||
learner->SetParams(param.cfg);
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec";
|
||||
std::string CliHelp() {
|
||||
return "Use xgboost -h for showing help information.\n";
|
||||
}
|
||||
|
||||
// start training.
|
||||
const double start = dmlc::GetTime();
|
||||
for (int i = version / 2; i < param.num_round; ++i) {
|
||||
double elapsed = dmlc::GetTime() - start;
|
||||
if (version % 2 == 0) {
|
||||
LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
|
||||
learner->UpdateOneIter(i, dtrain);
|
||||
if (learner->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(learner.get());
|
||||
void CLIError(dmlc::Error const& e) {
|
||||
std::cerr << "Error running xgboost:\n\n"
|
||||
<< e.what() << "\n"
|
||||
<< CliHelp()
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
class CLI {
|
||||
CLIParam param_;
|
||||
std::unique_ptr<Learner> learner_;
|
||||
enum Print {
|
||||
kNone,
|
||||
kVersion,
|
||||
kHelp
|
||||
} print_info_ {kNone};
|
||||
|
||||
int ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
|
||||
learner_.reset(Learner::Create(matrices));
|
||||
int version = rabit::LoadCheckPoint(learner_.get());
|
||||
if (version == 0) {
|
||||
if (param_.model_in != CLIParam::kNull) {
|
||||
this->LoadModel(param_.model_in, learner_.get());
|
||||
learner_->SetParams(param_.cfg);
|
||||
} else {
|
||||
rabit::CheckPoint(learner.get());
|
||||
learner_->SetParams(param_.cfg);
|
||||
}
|
||||
}
|
||||
learner_->Configure();
|
||||
return version;
|
||||
}
|
||||
|
||||
void CLITrain() {
|
||||
const double tstart_data_load = dmlc::GetTime();
|
||||
if (rabit::IsDistributed()) {
|
||||
std::string pname = rabit::GetProcessorName();
|
||||
LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
|
||||
}
|
||||
// load in data.
|
||||
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
|
||||
param_.train_path,
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param_.dsplit == 2));
|
||||
std::vector<std::shared_ptr<DMatrix>> deval;
|
||||
std::vector<std::shared_ptr<DMatrix>> cache_mats;
|
||||
std::vector<std::shared_ptr<DMatrix>> eval_datasets;
|
||||
cache_mats.push_back(dtrain);
|
||||
for (size_t i = 0; i < param_.eval_data_names.size(); ++i) {
|
||||
deval.emplace_back(std::shared_ptr<DMatrix>(DMatrix::Load(
|
||||
param_.eval_data_paths[i],
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param_.dsplit == 2)));
|
||||
eval_datasets.push_back(deval.back());
|
||||
cache_mats.push_back(deval.back());
|
||||
}
|
||||
std::vector<std::string> eval_data_names = param_.eval_data_names;
|
||||
if (param_.eval_train) {
|
||||
eval_datasets.push_back(dtrain);
|
||||
eval_data_names.emplace_back("train");
|
||||
}
|
||||
// initialize the learner.
|
||||
int32_t version = this->ResetLearner(cache_mats);
|
||||
LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load
|
||||
<< " sec";
|
||||
|
||||
// start training.
|
||||
const double start = dmlc::GetTime();
|
||||
for (int i = version / 2; i < param_.num_round; ++i) {
|
||||
double elapsed = dmlc::GetTime() - start;
|
||||
if (version % 2 == 0) {
|
||||
LOG(INFO) << "boosting round " << i << ", " << elapsed
|
||||
<< " sec elapsed";
|
||||
learner_->UpdateOneIter(i, dtrain);
|
||||
if (learner_->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(learner_.get());
|
||||
} else {
|
||||
rabit::CheckPoint(learner_.get());
|
||||
}
|
||||
version += 1;
|
||||
}
|
||||
CHECK_EQ(version, rabit::VersionNumber());
|
||||
std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
|
||||
if (rabit::IsDistributed()) {
|
||||
if (rabit::GetRank() == 0) {
|
||||
LOG(TRACKER) << res;
|
||||
}
|
||||
} else {
|
||||
LOG(CONSOLE) << res;
|
||||
}
|
||||
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
|
||||
rabit::GetRank() == 0) {
|
||||
std::ostringstream os;
|
||||
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
|
||||
<< i + 1 << ".model";
|
||||
this->SaveModel(os.str(), learner_.get());
|
||||
}
|
||||
|
||||
if (learner_->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(learner_.get());
|
||||
} else {
|
||||
rabit::CheckPoint(learner_.get());
|
||||
}
|
||||
version += 1;
|
||||
CHECK_EQ(version, rabit::VersionNumber());
|
||||
}
|
||||
CHECK_EQ(version, rabit::VersionNumber());
|
||||
std::string res = learner->EvalOneIter(i, eval_datasets, eval_data_names);
|
||||
if (rabit::IsDistributed()) {
|
||||
if (rabit::GetRank() == 0) {
|
||||
LOG(TRACKER) << res;
|
||||
}
|
||||
} else {
|
||||
LOG(CONSOLE) << res;
|
||||
}
|
||||
if (param.save_period != 0 &&
|
||||
(i + 1) % param.save_period == 0 &&
|
||||
rabit::GetRank() == 0) {
|
||||
LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start
|
||||
<< " sec";
|
||||
// always save final round
|
||||
if ((param_.save_period == 0 ||
|
||||
param_.num_round % param_.save_period != 0) &&
|
||||
param_.model_out != CLIParam::kNull && rabit::GetRank() == 0) {
|
||||
std::ostringstream os;
|
||||
os << param.model_dir << '/'
|
||||
<< std::setfill('0') << std::setw(4)
|
||||
<< i + 1 << ".model";
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(os.str().c_str(), "w"));
|
||||
if (param_.model_out == CLIParam::kNull) {
|
||||
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
|
||||
<< param_.num_round << ".model";
|
||||
} else {
|
||||
os << param_.model_out;
|
||||
}
|
||||
this->SaveModel(os.str(), learner_.get());
|
||||
}
|
||||
|
||||
double elapsed = dmlc::GetTime() - start;
|
||||
LOG(INFO) << "update end, " << elapsed << " sec in all";
|
||||
}
|
||||
|
||||
void CLIDumpModel() {
|
||||
FeatureMap fmap;
|
||||
if (param_.name_fmap != CLIParam::kNull) {
|
||||
std::unique_ptr<dmlc::Stream> fs(
|
||||
dmlc::Stream::Create(param_.name_fmap.c_str(), "r"));
|
||||
dmlc::istream is(fs.get());
|
||||
fmap.LoadText(is);
|
||||
}
|
||||
// load model
|
||||
CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for dump";
|
||||
this->ResetLearner({});
|
||||
|
||||
// dump data
|
||||
std::vector<std::string> dump =
|
||||
learner_->DumpModel(fmap, param_.dump_stats, param_.dump_format);
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(param_.name_dump.c_str(), "w"));
|
||||
dmlc::ostream os(fo.get());
|
||||
if (param_.dump_format == "json") {
|
||||
os << "[" << std::endl;
|
||||
for (size_t i = 0; i < dump.size(); ++i) {
|
||||
if (i != 0) {
|
||||
os << "," << std::endl;
|
||||
}
|
||||
os << dump[i]; // Dump the previously generated JSON here
|
||||
}
|
||||
os << std::endl << "]" << std::endl;
|
||||
} else {
|
||||
for (size_t i = 0; i < dump.size(); ++i) {
|
||||
os << "booster[" << i << "]:\n";
|
||||
os << dump[i];
|
||||
}
|
||||
}
|
||||
// force flush before fo destruct.
|
||||
os.set_stream(nullptr);
|
||||
}
|
||||
|
||||
void CLIPredict() {
|
||||
CHECK_NE(param_.test_path, CLIParam::kNull)
|
||||
<< "Test dataset parameter test:data must be specified.";
|
||||
// load data
|
||||
std::shared_ptr<DMatrix> dtest(DMatrix::Load(
|
||||
param_.test_path,
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param_.dsplit == 2));
|
||||
// load model
|
||||
CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict";
|
||||
this->ResetLearner({});
|
||||
|
||||
LOG(INFO) << "Start prediction...";
|
||||
HostDeviceVector<bst_float> preds;
|
||||
learner_->Predict(dtest, param_.pred_margin, &preds, param_.ntree_limit);
|
||||
LOG(CONSOLE) << "Writing prediction to " << param_.name_pred;
|
||||
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(param_.name_pred.c_str(), "w"));
|
||||
dmlc::ostream os(fo.get());
|
||||
for (bst_float p : preds.ConstHostVector()) {
|
||||
os << std::setprecision(std::numeric_limits<bst_float>::max_digits10) << p
|
||||
<< '\n';
|
||||
}
|
||||
// force flush before fo destruct.
|
||||
os.set_stream(nullptr);
|
||||
}
|
||||
|
||||
void LoadModel(std::string const& path, Learner* learner) const {
|
||||
if (common::FileExtension(path) == "json") {
|
||||
auto str = common::LoadSequentialFile(path);
|
||||
CHECK_GT(str.size(), 2);
|
||||
CHECK_EQ(str[0], '{');
|
||||
Json in{Json::Load({str.c_str(), str.size()})};
|
||||
learner->LoadModel(in);
|
||||
} else {
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path.c_str(), "r"));
|
||||
learner->LoadModel(fi.get());
|
||||
}
|
||||
}
|
||||
|
||||
void SaveModel(std::string const& path, Learner* learner) const {
|
||||
learner->Configure();
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path.c_str(), "w"));
|
||||
if (common::FileExtension(path) == "json") {
|
||||
Json out{Object()};
|
||||
learner->SaveModel(&out);
|
||||
std::string str;
|
||||
Json::Dump(out, &str);
|
||||
fo->Write(str.c_str(), str.size());
|
||||
} else {
|
||||
learner->SaveModel(fo.get());
|
||||
}
|
||||
}
|
||||
|
||||
if (learner->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(learner.get());
|
||||
} else {
|
||||
rabit::CheckPoint(learner.get());
|
||||
void PrintHelp() const {
|
||||
std::cout << "Usage: xgboost [ -h ] [ -V ] [ config file ] [ arguments ]" << std::endl;
|
||||
std::stringstream ss;
|
||||
ss << R"(
|
||||
Options and arguments:
|
||||
|
||||
-h, --help
|
||||
Print this message.
|
||||
|
||||
-V, --version
|
||||
Print XGBoost version.
|
||||
|
||||
arguments
|
||||
Extra parameters that are not specified in config file, see below.
|
||||
|
||||
Config file specifies the configuration for both training and testing. Each line
|
||||
containing the [attribute] = [value] configuration.
|
||||
|
||||
General XGBoost parameters:
|
||||
|
||||
https://xgboost.readthedocs.io/en/latest/parameter.html
|
||||
|
||||
Command line interface specfic parameters:
|
||||
|
||||
)";
|
||||
|
||||
std::string help = param_.__DOC__();
|
||||
auto splited = common::Split(help, '\n');
|
||||
for (auto str : splited) {
|
||||
ss << " " << str << '\n';
|
||||
}
|
||||
version += 1;
|
||||
CHECK_EQ(version, rabit::VersionNumber());
|
||||
ss << R"( eval[NAME]: string, optional, default='NULL'
|
||||
Path to evaluation data, with NAME as data name.
|
||||
)";
|
||||
|
||||
ss << R"(
|
||||
Example: train.conf
|
||||
|
||||
# General parameters
|
||||
booster = gbtree
|
||||
objective = reg:squarederror
|
||||
eta = 1.0
|
||||
gamma = 1.0
|
||||
seed = 0
|
||||
min_child_weight = 0
|
||||
max_depth = 3
|
||||
|
||||
# Training arguments for CLI.
|
||||
num_round = 2
|
||||
save_period = 0
|
||||
data = "demo/data/agaricus.txt.train?format=libsvm"
|
||||
eval[test] = "demo/data/agaricus.txt.test?format=libsvm"
|
||||
|
||||
See demo/ directory in XGBoost for more examples.
|
||||
)";
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start << " sec";
|
||||
// always save final round
|
||||
if ((param.save_period == 0 || param.num_round % param.save_period != 0) &&
|
||||
param.model_out != "NONE" &&
|
||||
rabit::GetRank() == 0) {
|
||||
std::ostringstream os;
|
||||
if (param.model_out == "NULL") {
|
||||
os << param.model_dir << '/'
|
||||
<< std::setfill('0') << std::setw(4)
|
||||
<< param.num_round << ".model";
|
||||
} else {
|
||||
os << param.model_out;
|
||||
|
||||
void PrintVersion() const {
|
||||
auto ver = Version::String(Version::Self());
|
||||
std::cout << "XGBoost: " << ver << std::endl;
|
||||
}
|
||||
|
||||
public:
|
||||
CLI(int argc, char* argv[]) {
|
||||
if (argc < 2) {
|
||||
this->PrintHelp();
|
||||
exit(1);
|
||||
}
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(os.str().c_str(), "w"));
|
||||
learner->SaveModel(fo.get());
|
||||
}
|
||||
|
||||
double elapsed = dmlc::GetTime() - start;
|
||||
LOG(INFO) << "update end, " << elapsed << " sec in all";
|
||||
}
|
||||
|
||||
void CLIDumpModel(const CLIParam& param) {
|
||||
FeatureMap fmap;
|
||||
if (param.name_fmap != "NULL") {
|
||||
std::unique_ptr<dmlc::Stream> fs(
|
||||
dmlc::Stream::Create(param.name_fmap.c_str(), "r"));
|
||||
dmlc::istream is(fs.get());
|
||||
fmap.LoadText(is);
|
||||
}
|
||||
// load model
|
||||
CHECK_NE(param.model_in, "NULL")
|
||||
<< "Must specify model_in for dump";
|
||||
std::unique_ptr<Learner> learner(Learner::Create({}));
|
||||
std::unique_ptr<dmlc::Stream> fi(
|
||||
dmlc::Stream::Create(param.model_in.c_str(), "r"));
|
||||
learner->SetParams(param.cfg);
|
||||
learner->LoadModel(fi.get());
|
||||
// dump data
|
||||
std::vector<std::string> dump = learner->DumpModel(
|
||||
fmap, param.dump_stats, param.dump_format);
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(param.name_dump.c_str(), "w"));
|
||||
dmlc::ostream os(fo.get());
|
||||
if (param.dump_format == "json") {
|
||||
os << "[" << std::endl;
|
||||
for (size_t i = 0; i < dump.size(); ++i) {
|
||||
if (i != 0) os << "," << std::endl;
|
||||
os << dump[i]; // Dump the previously generated JSON here
|
||||
for (int i = 0; i < argc; ++i) {
|
||||
std::string str {argv[i]};
|
||||
if (str == "-h" || str == "--help") {
|
||||
print_info_ = kHelp;
|
||||
break;
|
||||
} else if (str == "-V" || str == "--version") {
|
||||
print_info_ = kVersion;
|
||||
break;
|
||||
}
|
||||
}
|
||||
os << std::endl << "]" << std::endl;
|
||||
} else {
|
||||
for (size_t i = 0; i < dump.size(); ++i) {
|
||||
os << "booster[" << i << "]:\n";
|
||||
os << dump[i];
|
||||
if (print_info_ != kNone) {
|
||||
return;
|
||||
}
|
||||
|
||||
rabit::Init(argc, argv);
|
||||
std::string config_path = argv[1];
|
||||
|
||||
common::ConfigParser cp(config_path);
|
||||
auto cfg = cp.Parse();
|
||||
|
||||
for (int i = 2; i < argc; ++i) {
|
||||
char name[256], val[256];
|
||||
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
||||
cfg.emplace_back(std::string(name), std::string(val));
|
||||
}
|
||||
}
|
||||
|
||||
param_.Configure(cfg);
|
||||
}
|
||||
// force flush before fo destruct.
|
||||
os.set_stream(nullptr);
|
||||
}
|
||||
|
||||
void CLIPredict(const CLIParam& param) {
|
||||
CHECK_NE(param.test_path, "NULL")
|
||||
<< "Test dataset parameter test:data must be specified.";
|
||||
// load data
|
||||
std::shared_ptr<DMatrix> dtest(
|
||||
DMatrix::Load(
|
||||
param.test_path,
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param.dsplit == 2));
|
||||
// load model
|
||||
CHECK_NE(param.model_in, "NULL")
|
||||
<< "Must specify model_in for predict";
|
||||
std::unique_ptr<Learner> learner(Learner::Create({}));
|
||||
std::unique_ptr<dmlc::Stream> fi(
|
||||
dmlc::Stream::Create(param.model_in.c_str(), "r"));
|
||||
learner->LoadModel(fi.get());
|
||||
learner->SetParams(param.cfg);
|
||||
int Run() {
|
||||
switch (this->print_info_) {
|
||||
case kNone:
|
||||
break;
|
||||
case kVersion: {
|
||||
this->PrintVersion();
|
||||
return 0;
|
||||
}
|
||||
case kHelp: {
|
||||
this->PrintHelp();
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
LOG(INFO) << "start prediction...";
|
||||
HostDeviceVector<bst_float> preds;
|
||||
learner->Predict(dtest, param.pred_margin, &preds, param.ntree_limit);
|
||||
LOG(CONSOLE) << "writing prediction to " << param.name_pred;
|
||||
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(param.name_pred.c_str(), "w"));
|
||||
dmlc::ostream os(fo.get());
|
||||
for (bst_float p : preds.ConstHostVector()) {
|
||||
os << std::setprecision(std::numeric_limits<bst_float>::max_digits10)
|
||||
<< p << '\n';
|
||||
}
|
||||
// force flush before fo destruct.
|
||||
os.set_stream(nullptr);
|
||||
}
|
||||
|
||||
int CLIRunTask(int argc, char *argv[]) {
|
||||
if (argc < 2) {
|
||||
printf("Usage: <config>\n");
|
||||
try {
|
||||
switch (param_.task) {
|
||||
case kTrain:
|
||||
CLITrain();
|
||||
break;
|
||||
case kDumpModel:
|
||||
CLIDumpModel();
|
||||
break;
|
||||
case kPredict:
|
||||
CLIPredict();
|
||||
break;
|
||||
}
|
||||
} catch (dmlc::Error const& e) {
|
||||
xgboost::CLIError(e);
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
rabit::Init(argc, argv);
|
||||
|
||||
common::ConfigParser cp(argv[1]);
|
||||
auto cfg = cp.Parse();
|
||||
|
||||
for (int i = 2; i < argc; ++i) {
|
||||
char name[256], val[256];
|
||||
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
||||
cfg.emplace_back(std::string(name), std::string(val));
|
||||
}
|
||||
~CLI() {
|
||||
rabit::Finalize();
|
||||
}
|
||||
CLIParam param;
|
||||
param.Configure(cfg);
|
||||
|
||||
switch (param.task) {
|
||||
case kTrain: CLITrain(param); break;
|
||||
case kDumpModel: CLIDumpModel(param); break;
|
||||
case kPredict: CLIPredict(param); break;
|
||||
}
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
return xgboost::CLIRunTask(argc, argv);
|
||||
try {
|
||||
xgboost::CLI cli(argc, argv);
|
||||
return cli.Run();
|
||||
} catch (dmlc::Error const& e) {
|
||||
// This captures only the initialization error.
|
||||
xgboost::CLIError(e);
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -7,8 +7,6 @@
|
||||
#ifndef XGBOOST_COMMON_CONFIG_H_
|
||||
#define XGBOOST_COMMON_CONFIG_H_
|
||||
|
||||
#include <xgboost/logging.h>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <istream>
|
||||
@ -18,6 +16,8 @@
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
/*!
|
||||
@ -40,10 +40,16 @@ class ConfigParser {
|
||||
|
||||
std::string LoadConfigFile(const std::string& path) {
|
||||
std::ifstream fin(path, std::ios_base::in | std::ios_base::binary);
|
||||
CHECK(fin) << "Failed to open: " << path;
|
||||
std::string content{std::istreambuf_iterator<char>(fin),
|
||||
std::istreambuf_iterator<char>()};
|
||||
return content;
|
||||
CHECK(fin) << "Failed to open config file: \"" << path << "\"";
|
||||
try {
|
||||
std::string content{std::istreambuf_iterator<char>(fin),
|
||||
std::istreambuf_iterator<char>()};
|
||||
return content;
|
||||
} catch (std::ios_base::failure const &e) {
|
||||
LOG(FATAL) << "Failed to read config file: \"" << path << "\"\n"
|
||||
<< e.what();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
/*!
|
||||
|
||||
@ -30,7 +30,6 @@
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
#include "../common/io.h"
|
||||
#endif
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
|
||||
|
||||
@ -18,7 +18,6 @@
|
||||
#include <random>
|
||||
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "io.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
@ -114,5 +114,3 @@ XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft")
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
|
||||
@ -17,10 +17,12 @@
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
#include "../common/io.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
|
||||
#include "param.h"
|
||||
#include "updater_gpu_common.cuh"
|
||||
#include "constraints.cuh"
|
||||
|
||||
@ -5,6 +5,7 @@ import platform
|
||||
import xgboost
|
||||
import subprocess
|
||||
import numpy
|
||||
import json
|
||||
|
||||
|
||||
class TestCLI(unittest.TestCase):
|
||||
@ -27,20 +28,23 @@ data = {data_path}
|
||||
eval[test] = {data_path}
|
||||
'''
|
||||
|
||||
def test_cli_model(self):
|
||||
curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
|
||||
project_root = os.path.normpath(
|
||||
os.path.join(curdir, os.path.pardir, os.path.pardir))
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=project_root)
|
||||
curdir = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
|
||||
project_root = os.path.normpath(
|
||||
os.path.join(curdir, os.path.pardir, os.path.pardir))
|
||||
|
||||
def get_exe(self):
|
||||
if platform.system() == 'Windows':
|
||||
exe = 'xgboost.exe'
|
||||
else:
|
||||
exe = 'xgboost'
|
||||
exe = os.path.join(project_root, exe)
|
||||
exe = os.path.join(self.project_root, exe)
|
||||
assert os.path.exists(exe)
|
||||
return exe
|
||||
|
||||
def test_cli_model(self):
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=self.project_root)
|
||||
exe = self.get_exe()
|
||||
seed = 1994
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@ -102,3 +106,48 @@ eval[test] = {data_path}
|
||||
py_model_bin = fd.read()
|
||||
|
||||
assert hash(cli_model_bin) == hash(py_model_bin)
|
||||
|
||||
def test_cli_help(self):
|
||||
exe = self.get_exe()
|
||||
completed = subprocess.run([exe], stdout=subprocess.PIPE)
|
||||
error_msg = completed.stdout.decode('utf-8')
|
||||
ret = completed.returncode
|
||||
assert ret == 1
|
||||
assert error_msg.find('Usage') != -1
|
||||
assert error_msg.find('eval[NAME]') != -1
|
||||
|
||||
completed = subprocess.run([exe, '-V'], stdout=subprocess.PIPE)
|
||||
msg = completed.stdout.decode('utf-8')
|
||||
assert msg.find('XGBoost') != -1
|
||||
v = xgboost.__version__
|
||||
if v.find('SNAPSHOT') != -1:
|
||||
assert msg.split(':')[1].strip() == v.split('-')[0]
|
||||
else:
|
||||
assert msg.split(':')[1].strip() == v
|
||||
|
||||
def test_cli_model_json(self):
|
||||
exe = self.get_exe()
|
||||
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
|
||||
root=self.project_root)
|
||||
seed = 1994
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model_out_cli = os.path.join(
|
||||
tmpdir, 'test_load_cli_model-cli.json')
|
||||
config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')
|
||||
|
||||
train_conf = self.template.format(data_path=data_path,
|
||||
seed=seed,
|
||||
task='train',
|
||||
model_in='NULL',
|
||||
model_out=model_out_cli,
|
||||
test_path='NULL',
|
||||
name_pred='NULL')
|
||||
with open(config_path, 'w') as fd:
|
||||
fd.write(train_conf)
|
||||
|
||||
subprocess.run([exe, config_path])
|
||||
with open(model_out_cli, 'r') as fd:
|
||||
model = json.load(fd)
|
||||
|
||||
assert model['learner']['gradient_booster']['name'] == 'gbtree'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user