xgboost/src/cli_main.cc

531 lines
16 KiB
C++

/*!
* Copyright 2014-2020 by Contributors
* \file cli_main.cc
* \brief The command line interface program of xgboost.
* This file is not included in dynamic library.
*/
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#if !defined(NOMINMAX) && defined(_WIN32)
#define NOMINMAX
#endif // !defined(NOMINMAX)
#include <dmlc/timer.h>
#include <xgboost/learner.h>
#include <xgboost/data.h>
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <xgboost/parameter.h>
#include <iomanip>
#include <ctime>
#include <string>
#include <cstdio>
#include <cstring>
#include <vector>
#include "collective/communicator-inl.h"
#include "common/common.h"
#include "common/config.h"
#include "common/io.h"
#include "common/version.h"
#include "c_api/c_api_utils.h"
namespace xgboost {
/*! \brief Tasks the command line program can perform; stored in CLIParam::task. */
enum CLITask {
  /*! \brief selected with task=train */
  kTrain = 0,
  /*! \brief selected with task=dump (value 1) */
  kDumpModel,
  /*! \brief selected with task=pred (value 2) */
  kPredict
};
/*!
 * \brief Parameters consumed by the command line program itself.
 *
 * Keys that are not declared below are handled by Configure(): `eval[NAME]`
 * entries are collected into eval_data_names / eval_data_paths, everything
 * else is stored in `cfg`.
 */
struct CLIParam : public XGBoostParameter<CLIParam> {
  /*! \brief the task name */
  int task;
  /*! \brief whether evaluate training statistics */
  bool eval_train;
  /*! \brief number of boosting iterations */
  int num_round;
  /*! \brief the period to save the model, 0 means only save the final round model */
  int save_period;
  /*! \brief the path of training set */
  std::string train_path;
  /*! \brief path of test dataset */
  std::string test_path;
  /*! \brief the path of test model file, or file to restart training */
  std::string model_in;
  /*! \brief the path of final model file, to be saved */
  std::string model_out;
  /*! \brief the path of directory containing the saved models */
  std::string model_dir;
  /*! \brief name of predict file */
  std::string name_pred;
  /*! \brief data split mode */
  int dsplit;
  /*! \brief limit number of trees in prediction (deprecated, see iteration_begin/iteration_end) */
  int ntree_limit;
  /*! \brief beginning of boosted tree iteration used for prediction */
  int iteration_begin;
  /*! \brief end of boosted tree iteration used for prediction, 0 means all the trees */
  int iteration_end;
  /*! \brief whether to directly output margin value */
  bool pred_margin;
  /*! \brief whether dump statistics along with model */
  int dump_stats;
  /*! \brief what format to dump the model in */
  std::string dump_format;
  /*! \brief name of feature map */
  std::string name_fmap;
  /*! \brief name of dump file */
  std::string name_dump;
  /*! \brief the paths of validation data sets */
  std::vector<std::string> eval_data_paths;
  /*! \brief the names of the evaluation data used in output log */
  std::vector<std::string> eval_data_names;
  /*! \brief all the configurations */
  std::vector<std::pair<std::string, std::string> > cfg;

  /*! \brief sentinel value meaning "this path was not specified" */
  static constexpr char const* const kNull = "NULL";

  // declare parameters
  DMLC_DECLARE_PARAMETER(CLIParam) {
    // NOTE: declare everything except eval_data_paths.
    DMLC_DECLARE_FIELD(task).set_default(kTrain)
        .add_enum("train", kTrain)
        .add_enum("dump", kDumpModel)
        .add_enum("pred", kPredict)
        .describe("Task to be performed by the CLI program.");
    DMLC_DECLARE_FIELD(eval_train).set_default(false)
        .describe("Whether evaluate on training data during training.");
    DMLC_DECLARE_FIELD(num_round).set_default(10).set_lower_bound(1)
        .describe("Number of boosting iterations");
    DMLC_DECLARE_FIELD(save_period).set_default(0).set_lower_bound(0)
        .describe("The period to save the model, 0 means only save final model.");
    DMLC_DECLARE_FIELD(train_path).set_default("NULL")
        .describe("Training data path.");
    DMLC_DECLARE_FIELD(test_path).set_default("NULL")
        .describe("Test data path.");
    DMLC_DECLARE_FIELD(model_in).set_default("NULL")
        .describe("Input model path, if any.");
    DMLC_DECLARE_FIELD(model_out).set_default("NULL")
        .describe("Output model path, if any.");
    DMLC_DECLARE_FIELD(model_dir).set_default("./")
        .describe("Output directory of period checkpoint.");
    DMLC_DECLARE_FIELD(name_pred).set_default("pred.txt")
        .describe("Name of the prediction file.");
    DMLC_DECLARE_FIELD(dsplit).set_default(0)
        .add_enum("row", 0)
        .add_enum("col", 1)
        .describe("Data split mode.");
    DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
        .describe("(Deprecated) Use iteration_begin/iteration_end instead.");
    DMLC_DECLARE_FIELD(iteration_begin).set_default(0).set_lower_bound(0)
        .describe("Begining of boosted tree iteration used for prediction.");
    DMLC_DECLARE_FIELD(iteration_end).set_default(0).set_lower_bound(0)
        .describe("End of boosted tree iteration used for prediction. 0 means all the trees.");
    DMLC_DECLARE_FIELD(pred_margin).set_default(false)
        .describe("Whether to predict margin value instead of probability.");
    DMLC_DECLARE_FIELD(dump_stats).set_default(false)
        .describe("Whether dump the model statistics.");
    DMLC_DECLARE_FIELD(dump_format).set_default("text")
        .describe("What format to dump the model in.");
    DMLC_DECLARE_FIELD(name_fmap).set_default("NULL")
        .describe("Name of the feature map file.");
    DMLC_DECLARE_FIELD(name_dump).set_default("dump.txt")
        .describe("Name of the output dump text file.");
    // alias
    DMLC_DECLARE_ALIAS(train_path, data);
    DMLC_DECLARE_ALIAS(test_path, test:data);
    DMLC_DECLARE_ALIAS(name_fmap, fmap);
  }

  /*!
   * \brief Customized configure function of CLIParam.
   *
   * Updates the declared fields from `_cfg`, then sorts the keys that were not
   * recognised: `eval[NAME]` entries become an evaluation data set called NAME
   * whose path is the entry's value, every other entry is appended to `cfg`.
   * "validate_parameters=True" is always placed at the front of the appended
   * entries.
   *
   * \param _cfg key/value pairs from the config file and the command line.
   */
  inline void Configure(const std::vector<std::pair<std::string, std::string> >& _cfg) {
    // Don't copy the configuration to enable parameter validation.
    auto unknown_cfg = this->UpdateAllowUnknown(_cfg);
    this->cfg.emplace_back("validate_parameters", "True");

    std::string const eval_prefix{"eval["};
    for (const auto& kv : unknown_cfg) {
      std::string const& key = kv.first;
      if (key.compare(0, eval_prefix.size(), eval_prefix) != 0) {
        this->cfg.emplace_back(kv);
        continue;
      }
      // The name is everything between "eval[" and the first ']'.  A key whose
      // bracket is never closed yields the remainder of the key, which is what
      // the former sscanf("eval[%[^]]") scan produced.  Parsing into a
      // std::string removes the fixed 256 byte buffer that a long key could
      // overflow.
      auto const name_begin = eval_prefix.size();
      auto const name_end = key.find(']', name_begin);
      std::string const evname = key.substr(
          name_begin,
          name_end == std::string::npos ? std::string::npos : name_end - name_begin);
      CHECK(!evname.empty()) << "must specify evaluation name for display";
      eval_data_names.push_back(evname);
      eval_data_paths.push_back(kv.second);
    }

    // constraint.
    if (name_pred == "stdout") {
      save_period = 0;
    }
  }
};
// Namespace-scope definition of the static constexpr member declared (and
// initialised) inside CLIParam.  Before C++17 such a definition is required
// whenever the member is odr-used.
constexpr char const* const CLIParam::kNull;
// NOTE(review): presumably registers CLIParam with dmlc's parameter machinery
// (provides the storage behind DMLC_DECLARE_PARAMETER) -- confirm against
// dmlc/parameter.h.
DMLC_REGISTER_PARAMETER(CLIParam);
/*! \brief One-line hint telling the user how to display the help text. */
std::string CliHelp() {
  std::string hint{"Use xgboost -h for showing help information.\n"};
  return hint;
}
/*!
 * \brief Report an error on standard error, followed by the help hint.
 * \param e the error whose message is printed.
 */
void CLIError(dmlc::Error const& e) {
  auto& err = std::cerr;
  err << "Error running xgboost:\n\n";
  err << e.what() << "\n";
  err << CliHelp();
  // std::endl adds the trailing newline and flushes the stream.
  err << std::endl;
}
class CLI {
CLIParam param_;
std::unique_ptr<Learner> learner_;
enum Print {
kNone,
kVersion,
kHelp
} print_info_ {kNone};
/*!
 * \brief Replace learner_ with a freshly created learner and configure it.
 *
 * When param_.model_in is set, the model is loaded first; the CLI
 * configuration (param_.cfg) is then applied in either case.
 *
 * \param matrices matrices handed to Learner::Create.
 */
void ResetLearner(std::vector<std::shared_ptr<DMatrix>> const &matrices) {
  learner_.reset(Learner::Create(matrices));

  bool const has_input_model = param_.model_in != CLIParam::kNull;
  if (has_input_model) {
    this->LoadModel(param_.model_in, learner_.get());
  }
  // Parameters are applied after any model load.
  learner_->SetParams(param_.cfg);
  learner_->Configure();
}
void CLITrain() {
const double tstart_data_load = dmlc::GetTime();
if (collective::IsDistributed()) {
std::string pname = collective::GetProcessorName();
LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank();
}
// load in data.
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
static_cast<DataSplitMode>(param_.dsplit)));
std::vector<std::shared_ptr<DMatrix>> deval;
std::vector<std::shared_ptr<DMatrix>> cache_mats;
std::vector<std::shared_ptr<DMatrix>> eval_datasets;
cache_mats.push_back(dtrain);
for (size_t i = 0; i < param_.eval_data_names.size(); ++i) {
deval.emplace_back(std::shared_ptr<DMatrix>(
DMatrix::Load(param_.eval_data_paths[i],
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
static_cast<DataSplitMode>(param_.dsplit))));
eval_datasets.push_back(deval.back());
cache_mats.push_back(deval.back());
}
std::vector<std::string> eval_data_names = param_.eval_data_names;
if (param_.eval_train) {
eval_datasets.push_back(dtrain);
eval_data_names.emplace_back("train");
}
// initialize the learner.
this->ResetLearner(cache_mats);
LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load
<< " sec";
// start training.
const double start = dmlc::GetTime();
int32_t version = 0;
for (int i = version / 2; i < param_.num_round; ++i) {
double elapsed = dmlc::GetTime() - start;
if (version % 2 == 0) {
LOG(INFO) << "boosting round " << i << ", " << elapsed
<< " sec elapsed";
learner_->UpdateOneIter(i, dtrain);
version += 1;
}
std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
if (collective::IsDistributed()) {
if (collective::GetRank() == 0) {
LOG(TRACKER) << res;
}
} else {
LOG(CONSOLE) << res;
}
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
collective::GetRank() == 0) {
std::ostringstream os;
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
<< i + 1 << ".model";
this->SaveModel(os.str(), learner_.get());
}
version += 1;
}
LOG(INFO) << "Complete Training loop time: " << dmlc::GetTime() - start
<< " sec";
// always save final round
if ((param_.save_period == 0 ||
param_.num_round % param_.save_period != 0) &&
collective::GetRank() == 0) {
std::ostringstream os;
if (param_.model_out == CLIParam::kNull) {
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
<< param_.num_round << ".model";
} else {
os << param_.model_out;
}
this->SaveModel(os.str(), learner_.get());
}
double elapsed = dmlc::GetTime() - start;
LOG(INFO) << "update end, " << elapsed << " sec in all";
}
/*!
 * \brief Run the `dump` task: write a textual dump of the model in
 *        param_.model_in to param_.name_dump.
 *
 * With dump_format "json" the per-tree dumps are wrapped in a JSON array;
 * otherwise each dump is preceded by a "booster[i]:" header line.
 */
void CLIDumpModel() {
  FeatureMap fmap;
  bool const has_fmap = param_.name_fmap != CLIParam::kNull;
  if (has_fmap) {
    std::unique_ptr<dmlc::Stream> fs(
        dmlc::Stream::Create(param_.name_fmap.c_str(), "r"));
    dmlc::istream is(fs.get());
    fmap.LoadText(is);
  }

  // load model
  CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for dump";
  this->ResetLearner({});

  // dump data
  std::vector<std::string> const trees =
      learner_->DumpModel(fmap, param_.dump_stats, param_.dump_format);
  std::unique_ptr<dmlc::Stream> fo(
      dmlc::Stream::Create(param_.name_dump.c_str(), "w"));
  dmlc::ostream os(fo.get());

  if (param_.dump_format == "json") {
    os << "[" << std::endl;
    bool first = true;
    for (auto const& tree : trees) {
      // Separator goes between elements only.
      if (!first) {
        os << "," << std::endl;
      }
      first = false;
      os << tree;  // Each entry is already a JSON document.
    }
    os << std::endl << "]" << std::endl;
  } else {
    size_t tree_idx = 0;
    for (auto const& tree : trees) {
      os << "booster[" << tree_idx << "]:\n";
      os << tree;
      ++tree_idx;
    }
  }

  // force flush before fo destruct.
  os.set_stream(nullptr);
}
void CLIPredict() {
CHECK_NE(param_.test_path, CLIParam::kNull)
<< "Test dataset parameter test:data must be specified.";
// load data
std::shared_ptr<DMatrix> dtest(DMatrix::Load(
param_.test_path,
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
static_cast<DataSplitMode>(param_.dsplit)));
// load model
CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict";
this->ResetLearner({});
LOG(INFO) << "Start prediction...";
HostDeviceVector<bst_float> preds;
if (param_.ntree_limit != 0) {
param_.iteration_end = GetIterationFromTreeLimit(param_.ntree_limit, learner_.get());
LOG(WARNING) << "`ntree_limit` is deprecated, use `iteration_begin` and "
"`iteration_end` instead.";
}
learner_->Predict(dtest, param_.pred_margin, &preds, param_.iteration_begin,
param_.iteration_end);
LOG(CONSOLE) << "Writing prediction to " << param_.name_pred;
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(param_.name_pred.c_str(), "w"));
dmlc::ostream os(fo.get());
for (bst_float p : preds.ConstHostVector()) {
os << std::setprecision(std::numeric_limits<bst_float>::max_digits10) << p
<< '\n';
}
// force flush before fo destruct.
os.set_stream(nullptr);
}
/*!
 * \brief Load a model file into `learner`.
 *
 * A path with the extension "json" is read completely and parsed as JSON;
 * any other path is opened as a stream and passed to the learner directly.
 *
 * \param path    location of the model file.
 * \param learner learner receiving the model.
 */
void LoadModel(std::string const& path, Learner* learner) const {
  bool const is_json = common::FileExtension(path) == "json";
  if (!is_json) {
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path.c_str(), "r"));
    learner->LoadModel(fi.get());
    return;
  }

  auto str = common::LoadSequentialFile(path);
  // Cheap sanity checks before parsing: more than two bytes, starting with '{'.
  CHECK_GT(str.size(), 2);
  CHECK_EQ(str[0], '{');
  Json in{Json::Load({str.c_str(), str.size()})};
  learner->LoadModel(in);
}
/*!
 * \brief Write the model held by `learner` to `path`.
 *
 * The learner is configured first.  A path with the extension "json" receives
 * the JSON dump of the model; any other path is written through the
 * learner's stream based SaveModel.
 *
 * \param path    destination of the model file.
 * \param learner learner providing the model.
 */
void SaveModel(std::string const& path, Learner* learner) const {
  learner->Configure();
  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path.c_str(), "w"));

  if (common::FileExtension(path) != "json") {
    learner->SaveModel(fo.get());
    return;
  }

  Json out{Object()};
  learner->SaveModel(&out);
  std::string serialized;
  Json::Dump(out, &serialized);
  fo->Write(serialized.c_str(), serialized.size());
}
/*!
 * \brief Print the usage message to standard output.
 *
 * The message consists of the usage line, a fixed description of the options,
 * the documentation string of CLIParam (param_.__DOC__(), emitted line by
 * line with a leading blank), the description of `eval[NAME]` and an example
 * configuration file.
 */
void PrintHelp() const {
  std::cout << "Usage: xgboost [ -h ] [ -V ] [ config file ] [ arguments ]" << std::endl;
  std::stringstream ss;
  ss << R"(
Options and arguments:
-h, --help
Print this message.
-V, --version
Print XGBoost version.
arguments
Extra parameters that are not specified in config file, see below.
Config file specifies the configuration for both training and testing. Each line
containing the [attribute] = [value] configuration.
General XGBoost parameters:
https://xgboost.readthedocs.io/en/latest/parameter.html
Command line interface specific parameters:
)";
  std::string help = param_.__DOC__();
  auto splited = common::Split(help, '\n');
  // Iterate by reference: the lines are only read, no per-line copy needed.
  for (auto const& line : splited) {
    ss << " " << line << '\n';
  }
  // eval[NAME] is handled by CLIParam::Configure rather than declared as a
  // field, so it is documented by hand here.
  ss << R"( eval[NAME]: string, optional, default='NULL'
Path to evaluation data, with NAME as data name.
)";
  ss << R"(
Example: train.conf
# General parameters
booster = gbtree
objective = reg:squarederror
eta = 1.0
gamma = 1.0
seed = 0
min_child_weight = 0
max_depth = 3
# Training arguments for CLI.
num_round = 2
save_period = 0
data = "demo/data/agaricus.txt.train?format=libsvm"
eval[test] = "demo/data/agaricus.txt.test?format=libsvm"
See demo/ directory in XGBoost for more examples.
)";
  std::cout << ss.str() << std::endl;
}
/*! \brief Print the XGBoost version string to standard output. */
void PrintVersion() const {
  std::cout << "XGBoost: " << Version::String(Version::Self()) << std::endl;
}
public:
/*!
 * \brief Parse the command line and prepare the CLI.
 *
 * Without any argument the help text is printed and the process exits with
 * status 1.  If any argument is -h/--help or -V/--version only that request
 * is recorded and nothing else is initialised.  Otherwise argv[1] is parsed
 * as the configuration file, every following `name=value` argument is
 * appended to the configuration, the collective communicator is initialised
 * from it and the CLI parameters are configured.
 *
 * \param argc number of command line arguments.
 * \param argv command line arguments.
 */
CLI(int argc, char* argv[]) {
  if (argc < 2) {
    this->PrintHelp();
    exit(1);
  }

  for (int i = 0; i < argc; ++i) {
    std::string str {argv[i]};
    if (str == "-h" || str == "--help") {
      print_info_ = kHelp;
      break;
    } else if (str == "-V" || str == "--version") {
      print_info_ = kVersion;
      break;
    }
  }
  if (print_info_ != kNone) {
    return;
  }

  std::string config_path = argv[1];
  common::ConfigParser cp(config_path);
  auto cfg = cp.Parse();

  // Split each extra argument at its first '='.  The acceptance rules mirror
  // the former sscanf("%[^=]=%s") scan -- non-empty name, value is the first
  // whitespace-delimited token after '=' -- but without the fixed 256 byte
  // buffers that a long argument (e.g. a long path) could overflow.
  char const* const kWhitespace = " \t\n\v\f\r";
  for (int i = 2; i < argc; ++i) {
    std::string const arg{argv[i]};
    auto const eq = arg.find('=');
    if (eq == std::string::npos || eq == 0) {
      continue;  // not of the form name=value
    }
    auto const val_begin = arg.find_first_not_of(kWhitespace, eq + 1);
    if (val_begin == std::string::npos) {
      continue;  // empty value
    }
    auto const val_end = arg.find_first_of(kWhitespace, val_begin);
    auto const val_len =
        val_end == std::string::npos ? std::string::npos : val_end - val_begin;
    cfg.emplace_back(arg.substr(0, eq), arg.substr(val_begin, val_len));
  }

  // Initialize the collective communicator.
  Json json{JsonObject()};
  for (auto& kv : cfg) {
    json[kv.first] = String(kv.second);
  }
  collective::Init(json);

  param_.Configure(cfg);
}
/*!
 * \brief Execute what the command line asked for.
 *
 * A pending version/help request is served first; otherwise the task selected
 * by param_.task is run.
 *
 * \return 0 on success, 1 when the task raised a dmlc::Error (the error is
 *         reported through CLIError).
 */
int Run() {
  if (this->print_info_ == kVersion) {
    this->PrintVersion();
    return 0;
  }
  if (this->print_info_ == kHelp) {
    this->PrintHelp();
    return 0;
  }

  try {
    switch (param_.task) {
      case kTrain: {
        CLITrain();
        break;
      }
      case kDumpModel: {
        CLIDumpModel();
        break;
      }
      case kPredict: {
        CLIPredict();
        break;
      }
    }
  } catch (dmlc::Error const& e) {
    xgboost::CLIError(e);
    return 1;
  }
  return 0;
}
/*! \brief Finalize the collective communicator when the CLI object is destroyed. */
~CLI() { collective::Finalize(); }
};
} // namespace xgboost
/*!
 * \brief Program entry point: build the CLI from the arguments and run it.
 * \return the status of CLI::Run, or 1 when constructing the CLI raised a
 *         dmlc::Error.
 */
int main(int argc, char *argv[]) {
  int status = 0;
  try {
    xgboost::CLI cli(argc, argv);
    status = cli.Run();
  } catch (dmlc::Error const& e) {
    // This captures only the initialization error.
    xgboost::CLIError(e);
    status = 1;
  }
  return status;
}