diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 36c64c5d0..33c919bdd 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -446,6 +446,23 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, bst_ulong *out_len, const char ***out_dump_array); +/*! + * \brief dump model, return array of strings representing model dump + * \param handle handle + * \param fmap name to fmap can be empty string + * \param with_stats whether to dump with statistics + * \param format the format to dump the model in + * \param out_len length of output array + * \param out_dump_array pointer to hold representing dump of each model + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, + const char *fmap, + int with_stats, + const char *format, + bst_ulong *out_len, + const char ***out_dump_array); + /*! * \brief dump model, return array of strings representing model dump * \param handle handle @@ -465,6 +482,27 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle, bst_ulong *out_len, const char ***out_models); +/*! + * \brief dump model, return array of strings representing model dump + * \param handle handle + * \param fnum number of features + * \param fname names of features + * \param ftype types of features + * \param with_stats whether to dump with statistics + * \param format the format to dump the model in + * \param out_len length of output array + * \param out_models pointer to hold representing dump of each model + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle, + int fnum, + const char **fname, + const char **ftype, + int with_stats, + const char *format, + bst_ulong *out_len, + const char ***out_models); + /*! * \brief Get string attribute from Booster. 
* \param handle handle diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 770e982b7..6a562c1f8 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -225,7 +225,7 @@ struct RowSet { * - Provide a dmlc::Parser and pass into the DMatrix::Create * - Alternatively, if data can be represented by an URL, define a new dmlc::Parser and register by DMLC_REGISTER_DATA_PARSER; * - This works best for user defined data input source, such as data-base, filesystem. - * - Provdie a DataSource, that can be passed to DMatrix::Create + * - Provide a DataSource, that can be passed to DMatrix::Create * This can be used to re-use inmemory data structure into DMatrix. */ class DMatrix { diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 6c10aa155..92e75cd42 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -108,12 +108,15 @@ class GradientBooster { std::vector* out_preds, unsigned ntree_limit = 0) = 0; /*! - * \brief dump the model to text format + * \brief dump the model in the requested format * \param fmap feature map that may help give interpretations of feature - * \param option extra option of the dump model + * \param with_stats extra statistics while dumping model + * \param format the format to dump the model in * \return a vector of dump for boosters. */ - virtual std::vector Dump2Text(const FeatureMap& fmap, int option) const = 0; + virtual std::vector DumpModel(const FeatureMap& fmap, + bool with_stats, + std::string format) const = 0; /*! * \brief create a gradient booster from given name * \param name name of gradient booster diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 10af8dd33..319593c83 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -140,12 +140,15 @@ class Learner : public rabit::Serializable { */ bool AllowLazyCheckPoint() const; /*! 
- * \brief dump the model in text format + * \brief dump the model in the requested format * \param fmap feature map that may help give interpretations of feature - * \param option extra option of the dump model + * \param with_stats extra statistics while dumping model + * \param format the format to dump the model in * \return a vector of dump for boosters. */ - std::vector Dump2Text(const FeatureMap& fmap, int option) const; + std::vector DumpModel(const FeatureMap& fmap, + bool with_stats, + std::string format) const; /*! * \brief online prediction function, predict score for one instance at a time * NOTE: use the batch prediction interface if possible, batch prediction is usually diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 168d1e936..d17d52b29 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -480,12 +480,15 @@ class RegTree: public TreeModel { */ inline int GetNext(int pid, float fvalue, bool is_unknown) const; /*! - * \brief dump model to text string - * \param fmap feature map of feature types + * \brief dump the model in the requested format as a text string + * \param fmap feature map that may help give interpretations of feature * \param with_stats whether dump out statistics as well + * \param format the format to dump the model in * \return the string of dumped model */ - std::string Dump2Text(const FeatureMap& fmap, bool with_stats) const; + std::string DumpModel(const FeatureMap& fmap, + bool with_stats, + std::string format) const; }; // implementations of inline functions diff --git a/plugin/README.md b/plugin/README.md index 56d973fd3..de146a9ec 100644 --- a/plugin/README.md +++ b/plugin/README.md @@ -2,16 +2,17 @@ XGBoost Plugins Modules ======================= This folder contains plugin modules to xgboost that can be optionally installed. 
The plugin system helps us to extend xgboost with additional features, -and add experimental features that may not yet ready to be included in main project. +and add experimental features that may not yet be ready to be included in the +main project. To include a certain plugin, say ```plugin_a```, you only need to add the following line to the config.mk. ```makefile -# Add plugin by include the plugin in config +# Add plugin by including the plugin in config.mk XGB_PLUGINS += plugin/plugin_a/plugin.mk ``` -Then rebuild libxgboost by typing make, you can get a new library with the plugin enabled. +Then rebuild libxgboost by typing ```make```, you can get a new library with the plugin enabled. Link Static XGBoost Library with Plugins ---------------------------------------- @@ -20,7 +21,7 @@ If you only use ```libxgboost.so```(this include python and other bindings), you can ignore this section. When you want to link ```libxgboost.a``` with additional plugins included, -you will need to enabled whole archeive via The following option. +you will need to enable whole archive via the following option. ```bash --whole-archive libxgboost.a --no-whole-archive ``` @@ -30,3 +31,21 @@ Write Your Own Plugin You can plugin your own modules to xgboost by adding code to this folder, without modification to the main code repo. The [example](example) folder provides an example to write a plugin. + +List of register functions +-------------------------- +A plugin has to register a new functionality to xgboost to be able to use it.
+The register macros available to plugin writers are: + + - XGBOOST_REGISTER_METRIC - Register an evaluation metric + - XGBOOST_REGISTER_GBM - Register a new gradient booster that learns through + gradient statistics + - XGBOOST_REGISTER_OBJECTIVE - Register a new objective function used by xgboost + - XGBOOST_REGISTER_TREE_UPDATER - Register a new tree-updater which updates + the tree given the gradient information + +And from dmlc-core: + + - DMLC_REGISTER_PARAMETER - Register a set of parameters for a specific use case + - DMLC_REGISTER_DATA_PARSER - Register a data parser where the data can be + represented by a URL. This is used by DMatrix. diff --git a/plugin/example/README.md b/plugin/example/README.md index f0ff5478c..32b709047 100644 --- a/plugin/example/README.md +++ b/plugin/example/README.md @@ -2,10 +2,10 @@ XGBoost Plugin Example ====================== This folder provides an example of xgboost plugin. -There are three steps you need to to do to add plugin to xgboost +There are three steps you need to do to add a plugin to xgboost - Create your source .cc file, implement a new extension - In this example [custom_obj.cc](custom_obj.cc) -- Register this extension to xgboost via registration macr +- Register this extension to xgboost via a registration macro - In this example ```XGBOOST_REGISTER_OBJECTIVE``` in [this line](custom_obj.cc#L75) - Create a [plugin.mk](plugin.mk) on this folder diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index a64d1e03e..ac0bd08f3 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1038,7 +1038,7 @@ class Booster(object): if need_close: fout.close() - def get_dump(self, fmap='', with_stats=False): + def get_dump(self, fmap='', with_stats=False, dump_format="text"): """ Returns the dump the model as a list of strings.
""" @@ -1056,21 +1056,24 @@ class Booster(object): ftype = from_pystr_to_cstr(['q'] * flen) else: ftype = from_pystr_to_cstr(self.feature_types) - _check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle, - flen, - fname, - ftype, - int(with_stats), - ctypes.byref(length), - ctypes.byref(sarr))) + _check_call(_LIB.XGBoosterDumpModelExWithFeatures( + self.handle, + flen, + fname, + ftype, + int(with_stats), + c_str(dump_format), + ctypes.byref(length), + ctypes.byref(sarr))) else: if fmap != '' and not os.path.exists(fmap): raise ValueError("No such file: {0}".format(fmap)) - _check_call(_LIB.XGBoosterDumpModel(self.handle, - c_str(fmap), - int(with_stats), - ctypes.byref(length), - ctypes.byref(sarr))) + _check_call(_LIB.XGBoosterDumpModelEx(self.handle, + c_str(fmap), + int(with_stats), + c_str(dump_format), + ctypes.byref(length), + ctypes.byref(sarr))) res = from_cstr_to_pystr(sarr, length) return res diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 90a1278db..b118c8f20 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -662,13 +662,14 @@ inline void XGBoostDumpModelImpl( BoosterHandle handle, const FeatureMap& fmap, int with_stats, + const char *format, xgboost::bst_ulong* len, const char*** out_models) { std::vector& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str; std::vector& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp; Booster *bst = static_cast(handle); bst->LazyInit(); - str_vecs = bst->learner()->Dump2Text(fmap, with_stats != 0); + str_vecs = bst->learner()->DumpModel(fmap, with_stats != 0, format); charp_vecs.resize(str_vecs.size()); for (size_t i = 0; i < str_vecs.size(); ++i) { charp_vecs[i] = str_vecs[i].c_str(); @@ -681,6 +682,14 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, int with_stats, xgboost::bst_ulong* len, const char*** out_models) { + XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models); +} +XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, + const char* fmap, + 
int with_stats, + const char *format, + xgboost::bst_ulong* len, + const char*** out_models) { API_BEGIN(); FeatureMap featmap; if (strlen(fmap) != 0) { @@ -689,7 +698,7 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, dmlc::istream is(fs.get()); featmap.LoadText(is); } - XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models); + XGBoostDumpModelImpl(handle, featmap, with_stats, format, len, out_models); API_END(); } @@ -700,12 +709,23 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle, int fnum, const char** fname, const char** ftype, int with_stats, xgboost::bst_ulong* len, const char*** out_models) { + return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats, + "text", len, out_models); +} +XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle, + int fnum, + const char** fname, + const char** ftype, + int with_stats, + const char *format, + xgboost::bst_ulong* len, + const char*** out_models) { API_BEGIN(); FeatureMap featmap; for (int i = 0; i < fnum; ++i) { featmap.PushBack(i, fname[i], ftype[i]); } - XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models); + XGBoostDumpModelImpl(handle, featmap, with_stats, format, len, out_models); API_END(); } diff --git a/src/cli_main.cc b/src/cli_main.cc index cd01aa3ed..57791274f 100644 --- a/src/cli_main.cc +++ b/src/cli_main.cc @@ -27,7 +27,7 @@ namespace xgboost { enum CLITask { kTrain = 0, - kDump2Text = 1, + kDumpModel = 1, kPredict = 2 }; @@ -62,6 +62,8 @@ struct CLIParam : public dmlc::Parameter { bool pred_margin; /*! \brief whether dump statistics along with model */ int dump_stats; + /*! \brief what format to dump the model in */ + std::string dump_format; /*! \brief name of feature map */ std::string name_fmap; /*! \brief name of dump file */ @@ -78,7 +80,7 @@ struct CLIParam : public dmlc::Parameter { // NOTE: declare everything except eval_data_paths.
DMLC_DECLARE_FIELD(task).set_default(kTrain) .add_enum("train", kTrain) - .add_enum("dump", kDump2Text) + .add_enum("dump", kDumpModel) .add_enum("pred", kPredict) .describe("Task to be performed by the CLI program."); DMLC_DECLARE_FIELD(silent).set_default(0).set_range(0, 2) @@ -112,6 +114,8 @@ struct CLIParam : public dmlc::Parameter { .describe("Whether to predict margin value instead of probability."); DMLC_DECLARE_FIELD(dump_stats).set_default(false) .describe("Whether dump the model statistics."); + DMLC_DECLARE_FIELD(dump_format).set_default("text") + .describe("What format to dump the model in."); DMLC_DECLARE_FIELD(name_fmap).set_default("NULL") .describe("Name of the feature map file."); DMLC_DECLARE_FIELD(name_dump).set_default("dump.txt") @@ -259,7 +263,7 @@ void CLITrain(const CLIParam& param) { } } -void CLIDump2Text(const CLIParam& param) { +void CLIDumpModel(const CLIParam& param) { FeatureMap fmap; if (param.name_fmap != "NULL") { std::unique_ptr fs( @@ -276,13 +280,23 @@ void CLIDump2Text(const CLIParam& param) { learner->Configure(param.cfg); learner->Load(fi.get()); // dump data - std::vector dump = learner->Dump2Text(fmap, param.dump_stats); + std::vector dump = learner->DumpModel( + fmap, param.dump_stats, param.dump_format); std::unique_ptr fo( dmlc::Stream::Create(param.name_dump.c_str(), "w")); dmlc::ostream os(fo.get()); - for (size_t i = 0; i < dump.size(); ++i) { - os << "booster[" << i << "]:\n"; - os << dump[i]; + if (param.dump_format == "json") { + os << "[" << std::endl; + for (size_t i = 0; i < dump.size(); ++i) { + if (i != 0) os << "," << std::endl; + os << dump[i]; // Dump the previously generated JSON here + } + os << std::endl << "]" << std::endl; + } else { + for (size_t i = 0; i < dump.size(); ++i) { + os << "booster[" << i << "]:\n"; + os << dump[i]; + } } // force flush before fo destruct. 
os.set_stream(nullptr); @@ -347,7 +361,7 @@ int CLIRunTask(int argc, char *argv[]) { switch (param.task) { case kTrain: CLITrain(param); break; - case kDump2Text: CLIDump2Text(param); break; + case kDumpModel: CLIDumpModel(param); break; case kPredict: CLIPredict(param); break; } rabit::Finalize(); diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 02a3445ec..59bf501fd 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -43,22 +43,22 @@ class SparsePageSource : public DataSource { /*! * \brief Create source by taking data from parser. * \param src source parser. - * \param cache_prefix The cache_prefix of cache file location. + * \param cache_info The cache_info of cache file location. */ static void Create(dmlc::Parser* src, - const std::string& cache_prefix); + const std::string& cache_info); /*! * \brief Create source cache by copy content from DMatrix. - * \param cache_prefix The cache_prefix of cache file location. + * \param cache_info The cache_info of cache file location. */ static void Create(DMatrix* src, - const std::string& cache_prefix); + const std::string& cache_info); /*! * \brief Check if the cache file already exists. - * \param cache_prefix The cache prefix of files. + * \param cache_info The cache prefix of files. * \return Whether cache file already exists. */ - static bool CacheExist(const std::string& cache_prefix); + static bool CacheExist(const std::string& cache_info); /*! \brief page size 32 MB */ static const size_t kPageSize = 32UL << 20UL; /*! 
\brief magic number used to identify Page */ diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index da5446570..38c46cfa3 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -224,16 +224,35 @@ class GBLinear : public GradientBooster { LOG(FATAL) << "gblinear does not support predict leaf index"; } - std::vector Dump2Text(const FeatureMap& fmap, int option) const override { + std::vector DumpModel(const FeatureMap& fmap, + bool with_stats, + std::string format) const override { std::stringstream fo(""); - fo << "bias:\n"; - for (int i = 0; i < model.param.num_output_group; ++i) { - fo << model.bias()[i] << std::endl; - } - fo << "weight:\n"; - for (int i = 0; i < model.param.num_output_group; ++i) { - for (unsigned j = 0; j v; diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 6d7d79a94..a9b721812 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -64,7 +64,7 @@ struct DartTrainParam : public dmlc::Parameter { // declare parameters DMLC_DECLARE_PARAMETER(DartTrainParam) { DMLC_DECLARE_FIELD(silent).set_default(false) - .describe("Not print information during trainig."); + .describe("Not print information during training."); DMLC_DECLARE_FIELD(sample_type).set_default(0) .add_enum("uniform", 0) .add_enum("weighted", 1) @@ -275,10 +275,12 @@ class GBTree : public GradientBooster { this->PredPath(p_fmat, out_preds, ntree_limit); } - std::vector Dump2Text(const FeatureMap& fmap, int option) const override { + std::vector DumpModel(const FeatureMap& fmap, + bool with_stats, + std::string format) const override { std::vector dump; for (size_t i = 0; i < trees.size(); i++) { - dump.push_back(trees[i]->Dump2Text(fmap, option & 1)); + dump.push_back(trees[i]->DumpModel(fmap, with_stats, format)); } return dump; } diff --git a/src/learner.cc b/src/learner.cc index cdf74c353..b7325ab11 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -25,8 +25,10 @@ bool Learner::AllowLazyCheckPoint() const { } std::vector -Learner::Dump2Text(const FeatureMap& 
fmap, int option) const { - return gbm_->Dump2Text(fmap, option); +Learner::DumpModel(const FeatureMap& fmap, + bool with_stats, + std::string format) const { + return gbm_->DumpModel(fmap, with_stats, format); } diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 06fb0055b..559560fba 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -15,19 +15,33 @@ namespace tree { DMLC_REGISTER_PARAMETER(TrainParam); } // internal function to dump regression tree to text -void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*) - const RegTree& tree, - const FeatureMap& fmap, - int nid, int depth, bool with_stats) { - for (int i = 0; i < depth; ++i) { - fo << '\t'; +void DumpRegTree(std::stringstream& fo, // NOLINT(*) + const RegTree& tree, + const FeatureMap& fmap, + int nid, int depth, int add_comma, + bool with_stats, std::string format) { + if (format == "json") { + if (add_comma) fo << ","; + if (depth != 0) fo << std::endl; + for (int i = 0; i < depth+1; ++i) fo << " "; + } else { + for (int i = 0; i < depth; ++i) fo << '\t'; } if (tree[nid].is_leaf()) { - fo << nid << ":leaf=" << tree[nid].leaf_value(); - if (with_stats) { - fo << ",cover=" << tree.stat(nid).sum_hess; + if (format == "json") { + fo << "{ \"nodeid\": " << nid + << ", \"leaf\": " << tree[nid].leaf_value(); + if (with_stats) { + fo << ", \"cover\": " << tree.stat(nid).sum_hess; + } + fo << " }"; + } else { + fo << nid << ":leaf=" << tree[nid].leaf_value(); + if (with_stats) { + fo << ",cover=" << tree.stat(nid).sum_hess; + } + fo << '\n'; } - fo << '\n'; } else { // right then left, bst_float cond = tree[nid].split_cond(); @@ -37,47 +51,101 @@ void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*) case FeatureMap::kIndicator: { int nyes = tree[nid].default_left() ? 
tree[nid].cright() : tree[nid].cleft(); - fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes - << ",no=" << tree[nid].cdefault(); + if (format == "json") { + fo << "{ \"nodeid\": " << nid + << ", \"depth\": " << depth + << ", \"split\": \"" << fmap.name(split_index) << "\"" + << ", \"yes\": " << nyes + << ", \"no\": " << tree[nid].cdefault(); + } else { + fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes + << ",no=" << tree[nid].cdefault(); + } break; } case FeatureMap::kInteger: { - fo << nid << ":[" << fmap.name(split_index) << "<" - << int(float(cond)+1.0f) - << "] yes=" << tree[nid].cleft() - << ",no=" << tree[nid].cright() - << ",missing=" << tree[nid].cdefault(); + if (format == "json") { + fo << "{ \"nodeid\": " << nid + << ", \"depth\": " << depth + << ", \"split\": \"" << fmap.name(split_index) << "\"" + << ", \"split_condition\": " << int(float(cond) + 1.0f) + << ", \"yes\": " << tree[nid].cleft() + << ", \"no\": " << tree[nid].cright() + << ", \"missing\": " << tree[nid].cdefault(); + } else { + fo << nid << ":[" << fmap.name(split_index) << "<" + << int(float(cond)+1.0f) + << "] yes=" << tree[nid].cleft() + << ",no=" << tree[nid].cright() + << ",missing=" << tree[nid].cdefault(); + } break; } case FeatureMap::kFloat: case FeatureMap::kQuantitive: { - fo << nid << ":[" << fmap.name(split_index) << "<"<< float(cond) - << "] yes=" << tree[nid].cleft() - << ",no=" << tree[nid].cright() - << ",missing=" << tree[nid].cdefault(); - break; + if (format == "json") { + fo << "{ \"nodeid\": " << nid + << ", \"depth\": " << depth + << ", \"split\": \"" << fmap.name(split_index) << "\"" + << ", \"split_condition\": " << float(cond) + << ", \"yes\": " << tree[nid].cleft() + << ", \"no\": " << tree[nid].cright() + << ", \"missing\": " << tree[nid].cdefault(); + } else { + fo << nid << ":[" << fmap.name(split_index) << "<" << float(cond) + << "] yes=" << tree[nid].cleft() + << ",no=" << tree[nid].cright() + << ",missing=" << 
tree[nid].cdefault(); + } + break; } default: LOG(FATAL) << "unknown fmap type"; } } else { - fo << nid << ":[f" << split_index << "<"<< float(cond) - << "] yes=" << tree[nid].cleft() - << ",no=" << tree[nid].cright() - << ",missing=" << tree[nid].cdefault(); + if (format == "json") { + fo << "{ \"nodeid\": " << nid + << ", \"depth\": " << depth + << ", \"split\": " << split_index + << ", \"split_condition\": " << float(cond) + << ", \"yes\": " << tree[nid].cleft() + << ", \"no\": " << tree[nid].cright() + << ", \"missing\": " << tree[nid].cdefault(); + } else { + fo << nid << ":[f" << split_index << "<"<< float(cond) + << "] yes=" << tree[nid].cleft() + << ",no=" << tree[nid].cright() + << ",missing=" << tree[nid].cdefault(); + } } if (with_stats) { - fo << ",gain=" << tree.stat(nid).loss_chg << ",cover=" << tree.stat(nid).sum_hess; + if (format == "json") { + fo << ", \"gain\": " << tree.stat(nid).loss_chg + << ", \"cover\": " << tree.stat(nid).sum_hess; + } else { + fo << ",gain=" << tree.stat(nid).loss_chg << ",cover=" << tree.stat(nid).sum_hess; + } + } + if (format == "json") { + fo << ", \"children\": ["; + } else { + fo << '\n'; + } + DumpRegTree(fo, tree, fmap, tree[nid].cleft(), depth + 1, false, with_stats, format); + DumpRegTree(fo, tree, fmap, tree[nid].cright(), depth + 1, true, with_stats, format); + if (format == "json") { + fo << std::endl; + for (int i = 0; i < depth+1; ++i) fo << " "; + fo << "]}"; } - fo << '\n'; - DumpRegTree2Text(fo, tree, fmap, tree[nid].cleft(), depth + 1, with_stats); - DumpRegTree2Text(fo, tree, fmap, tree[nid].cright(), depth + 1, with_stats); } } -std::string RegTree::Dump2Text(const FeatureMap& fmap, bool with_stats) const { +std::string RegTree::DumpModel(const FeatureMap& fmap, + bool with_stats, + std::string format) const { std::stringstream fo(""); for (int i = 0; i < param.num_roots; ++i) { - DumpRegTree2Text(fo, *this, fmap, i, 0, with_stats); + DumpRegTree(fo, *this, fmap, i, 0, false, with_stats, format); } 
return fo.str(); } diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 710de987d..3314060df 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -2,6 +2,7 @@ import numpy as np import xgboost as xgb import unittest +import json dpath = 'demo/data/' rng = np.random.RandomState(1994) @@ -170,6 +171,39 @@ class TestBasic(unittest.TestCase): fscores = bst.get_fscore() assert scores1 == fscores + def test_dump(self): + data = np.random.randn(100, 2) + target = np.array([0, 1] * 50) + features = ['Feature1', 'Feature2'] + + dm = xgb.DMatrix(data, label=target, feature_names=features) + params = {'objective': 'binary:logistic', + 'eval_metric': 'logloss', + 'eta': 0.3, + 'max_depth': 1} + + bst = xgb.train(params, dm, num_boost_round=1) + + # number of feature importances should == number of features + dump1 = bst.get_dump() + self.assertEqual(len(dump1), 1, "Expected only 1 tree to be dumped.") + self.assertEqual(len(dump1[0].splitlines()), 3, + "Expected 1 root and 2 leaves - 3 lines in dump.") + + dump2 = bst.get_dump(with_stats=True) + self.assertEqual(dump2[0].count('\n'), 3, + "Expected 1 root and 2 leaves - 3 lines in dump.") + self.assertGreater(dump2[0].find('\n'), dump1[0].find('\n'), + "Expected more info when with_stats=True is given.") + + dump3 = bst.get_dump(dump_format="json") + dump3j = json.loads(dump3[0]) + self.assertEqual(dump3j["nodeid"], 0, "Expected the root node on top.") + + dump4 = bst.get_dump(dump_format="json", with_stats=True) + dump4j = json.loads(dump4[0]) + self.assertIn("gain", dump4j, "Expected 'gain' to be dumped in JSON.") + def test_load_file_invalid(self): self.assertRaises(xgb.core.XGBoostError, xgb.Booster, model_file='incorrect_path')