Add dump_format=json option (#1726)
* Add format to the params accepted by DumpModel Currently, only the text format is supported when trying to dump a model. The plan is to add more such formats like JSON which are easy to read and/or parse by machines. And to make the interface for this even more generic to allow other formats to be added. Hence, we make some modifications to make these functions generic and accept a new parameter "format" which signifies the format of the dump to be created. * Fix typos and errors in docs * plugin: Mention all the register macros available Document the register macros currently available to the plugin writers so they know what exactly can be extended using hooks. * sparse_page_source: Use same arg name in .h and .cc * gbm: Add JSON dump The dump_format argument can be used to specify what type of dump file should be created. Add functionality to dump gblinear and gbtree into a JSON file. The JSON file has an array, each item is a JSON object for the tree. For gblinear: - The item is the bias and weights vectors For gbtree: - The item is the root node. The root node has an attribute "children" which holds the children nodes. This happens recursively. * core.py: Add arg dump_format for get_dump()
This commit is contained in:
parent
9c693f0f5f
commit
b94fcab4dc
@ -446,6 +446,23 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
|
||||
bst_ulong *out_len,
|
||||
const char ***out_dump_array);
|
||||
|
||||
/*!
|
||||
* \brief dump model, return array of strings representing model dump
|
||||
* \param handle handle
|
||||
* \param fmap name to fmap can be empty string
|
||||
* \param with_stats whether to dump with statistics
|
||||
* \param format the format to dump the model in
|
||||
* \param out_len length of output array
|
||||
* \param out_dump_array pointer to hold representing dump of each model
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
|
||||
const char *fmap,
|
||||
int with_stats,
|
||||
const char *format,
|
||||
bst_ulong *out_len,
|
||||
const char ***out_dump_array);
|
||||
|
||||
/*!
|
||||
* \brief dump model, return array of strings representing model dump
|
||||
* \param handle handle
|
||||
@ -465,6 +482,27 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
||||
bst_ulong *out_len,
|
||||
const char ***out_models);
|
||||
|
||||
/*!
|
||||
* \brief dump model, return array of strings representing model dump
|
||||
* \param handle handle
|
||||
* \param fnum number of features
|
||||
* \param fname names of features
|
||||
* \param ftype types of features
|
||||
* \param with_stats whether to dump with statistics
|
||||
* \param format the format to dump the model in
|
||||
* \param out_len length of output array
|
||||
* \param out_models pointer to hold representing dump of each model
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
|
||||
int fnum,
|
||||
const char **fname,
|
||||
const char **ftype,
|
||||
int with_stats,
|
||||
const char *format,
|
||||
bst_ulong *out_len,
|
||||
const char ***out_models);
|
||||
|
||||
/*!
|
||||
* \brief Get string attribute from Booster.
|
||||
* \param handle handle
|
||||
|
||||
@ -225,7 +225,7 @@ struct RowSet {
|
||||
* - Provide a dmlc::Parser and pass into the DMatrix::Create
|
||||
* - Alternatively, if data can be represented by an URL, define a new dmlc::Parser and register by DMLC_REGISTER_DATA_PARSER;
|
||||
* - This works best for user defined data input source, such as data-base, filesystem.
|
||||
* - Provdie a DataSource, that can be passed to DMatrix::Create
|
||||
* - Provide a DataSource, that can be passed to DMatrix::Create
|
||||
* This can be used to re-use inmemory data structure into DMatrix.
|
||||
*/
|
||||
class DMatrix {
|
||||
|
||||
@ -108,12 +108,15 @@ class GradientBooster {
|
||||
std::vector<float>* out_preds,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
/*!
|
||||
* \brief dump the model to text format
|
||||
* \brief dump the model in the requested format
|
||||
* \param fmap feature map that may help give interpretations of feature
|
||||
* \param option extra option of the dump model
|
||||
* \param with_stats extra statistics while dumping model
|
||||
* \param format the format to dump the model in
|
||||
* \return a vector of dump for boosters.
|
||||
*/
|
||||
virtual std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const = 0;
|
||||
virtual std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const = 0;
|
||||
/*!
|
||||
* \brief create a gradient booster from given name
|
||||
* \param name name of gradient booster
|
||||
|
||||
@ -140,12 +140,15 @@ class Learner : public rabit::Serializable {
|
||||
*/
|
||||
bool AllowLazyCheckPoint() const;
|
||||
/*!
|
||||
* \brief dump the model in text format
|
||||
* \brief dump the model in the requested format
|
||||
* \param fmap feature map that may help give interpretations of feature
|
||||
* \param option extra option of the dump model
|
||||
* \param with_stats extra statistics while dumping model
|
||||
* \param format the format to dump the model in
|
||||
* \return a vector of dump for boosters.
|
||||
*/
|
||||
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const;
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const;
|
||||
/*!
|
||||
* \brief online prediction function, predict score for one instance at a time
|
||||
* NOTE: use the batch prediction interface if possible, batch prediction is usually
|
||||
|
||||
@ -480,12 +480,15 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
|
||||
*/
|
||||
inline int GetNext(int pid, float fvalue, bool is_unknown) const;
|
||||
/*!
|
||||
* \brief dump model to text string
|
||||
* \param fmap feature map of feature types
|
||||
* \brief dump the model in the requested format as a text string
|
||||
* \param fmap feature map that may help give interpretations of feature
|
||||
* \param with_stats whether dump out statistics as well
|
||||
* \param format the format to dump the model in
|
||||
* \return the string of dumped model
|
||||
*/
|
||||
std::string Dump2Text(const FeatureMap& fmap, bool with_stats) const;
|
||||
std::string DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const;
|
||||
};
|
||||
|
||||
// implementations of inline functions
|
||||
|
||||
@ -2,16 +2,17 @@ XGBoost Plugins Modules
|
||||
=======================
|
||||
This folder contains plugin modules to xgboost that can be optionally installed.
|
||||
The plugin system helps us to extend xgboost with additional features,
|
||||
and add experimental features that may not yet ready to be included in main project.
|
||||
and add experimental features that may not yet be ready to be included in the
|
||||
main project.
|
||||
|
||||
To include a certain plugin, say ```plugin_a```, you only need to add the following line to the config.mk.
|
||||
|
||||
```makefile
|
||||
# Add plugin by include the plugin in config
|
||||
# Add plugin by including the plugin in config.mk
|
||||
XGB_PLUGINS += plugin/plugin_a/plugin.mk
|
||||
```
|
||||
|
||||
Then rebuild libxgboost by typing make, you can get a new library with the plugin enabled.
|
||||
Then rebuild libxgboost by typing ```make```, you can get a new library with the plugin enabled.
|
||||
|
||||
Link Static XGBoost Library with Plugins
|
||||
----------------------------------------
|
||||
@ -20,7 +21,7 @@ If you only use ```libxgboost.so```(this include python and other bindings),
|
||||
you can ignore this section.
|
||||
|
||||
When you want to link ```libxgboost.a``` with additional plugins included,
|
||||
you will need to enabled whole archeive via The following option.
|
||||
you will need to enable whole archive via the following option.
|
||||
```bash
|
||||
--whole-archive libxgboost.a --no-whole-archive
|
||||
```
|
||||
@ -30,3 +31,21 @@ Write Your Own Plugin
|
||||
You can plugin your own modules to xgboost by adding code to this folder,
|
||||
without modification to the main code repo.
|
||||
The [example](example) folder provides an example to write a plugin.
|
||||
|
||||
List of register functions
|
||||
--------------------------
|
||||
A plugin has to register a new functionality to xgboost to be able to use it.
|
||||
The register macros available to plugin writers are:
|
||||
|
||||
- XGBOOST_REGISTER_METRIC - Register an evaluation metric
|
||||
- XGBOOST_REGISTER_GBM - Register a new gradient booster that learns through
|
||||
gradient statistics
|
||||
- XGBOOST_REGISTER_OBJECTIVE - Register a new objective function used by xgboost
|
||||
- XGBOOST_REGISTER_TREE_UPDATER - Register a new tree-updater which updates
|
||||
the tree given the gradient information
|
||||
|
||||
And from dmlc-core:
|
||||
|
||||
- DMLC_REGISTER_PARAMETER - Register a set of parameters for a specific use case
|
||||
- DMLC_REGISTER_DATA_PARSER - Register a data parser where the data can be
|
||||
represented by a URL. This is used by DMatrix.
|
||||
|
||||
@ -2,10 +2,10 @@ XGBoost Plugin Example
|
||||
======================
|
||||
This folder provides an example of xgboost plugin.
|
||||
|
||||
There are three steps you need to to do to add plugin to xgboost
|
||||
There are three steps you need to do to add a plugin to xgboost
|
||||
- Create your source .cc file, implement a new extension
|
||||
- In this example [custom_obj.cc](custom_obj.cc)
|
||||
- Register this extension to xgboost via registration macr
|
||||
- Register this extension to xgboost via a registration macro
|
||||
- In this example ```XGBOOST_REGISTER_OBJECTIVE``` in [this line](custom_obj.cc#L75)
|
||||
- Create a [plugin.mk](plugin.mk) in this folder
|
||||
|
||||
|
||||
@ -1038,7 +1038,7 @@ class Booster(object):
|
||||
if need_close:
|
||||
fout.close()
|
||||
|
||||
def get_dump(self, fmap='', with_stats=False):
|
||||
def get_dump(self, fmap='', with_stats=False, dump_format="text"):
|
||||
"""
|
||||
Returns the dump of the model as a list of strings.
|
||||
"""
|
||||
@ -1056,19 +1056,22 @@ class Booster(object):
|
||||
ftype = from_pystr_to_cstr(['q'] * flen)
|
||||
else:
|
||||
ftype = from_pystr_to_cstr(self.feature_types)
|
||||
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
|
||||
_check_call(_LIB.XGBoosterDumpModelExWithFeatures(
|
||||
self.handle,
|
||||
flen,
|
||||
fname,
|
||||
ftype,
|
||||
int(with_stats),
|
||||
c_str(dump_format),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
else:
|
||||
if fmap != '' and not os.path.exists(fmap):
|
||||
raise ValueError("No such file: {0}".format(fmap))
|
||||
_check_call(_LIB.XGBoosterDumpModel(self.handle,
|
||||
_check_call(_LIB.XGBoosterDumpModelEx(self.handle,
|
||||
c_str(fmap),
|
||||
int(with_stats),
|
||||
c_str(dump_format),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
res = from_cstr_to_pystr(sarr, length)
|
||||
|
||||
@ -662,13 +662,14 @@ inline void XGBoostDumpModelImpl(
|
||||
BoosterHandle handle,
|
||||
const FeatureMap& fmap,
|
||||
int with_stats,
|
||||
const char *format,
|
||||
xgboost::bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
|
||||
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
str_vecs = bst->learner()->Dump2Text(fmap, with_stats != 0);
|
||||
str_vecs = bst->learner()->DumpModel(fmap, with_stats != 0, format);
|
||||
charp_vecs.resize(str_vecs.size());
|
||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||
charp_vecs[i] = str_vecs[i].c_str();
|
||||
@ -681,6 +682,14 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
|
||||
int with_stats,
|
||||
xgboost::bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models);
|
||||
}
|
||||
XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
|
||||
const char* fmap,
|
||||
int with_stats,
|
||||
const char *format,
|
||||
xgboost::bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
API_BEGIN();
|
||||
FeatureMap featmap;
|
||||
if (strlen(fmap) != 0) {
|
||||
@ -689,7 +698,7 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
|
||||
dmlc::istream is(fs.get());
|
||||
featmap.LoadText(is);
|
||||
}
|
||||
XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models);
|
||||
XGBoostDumpModelImpl(handle, featmap, with_stats, format, len, out_models);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -700,12 +709,23 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
||||
int with_stats,
|
||||
xgboost::bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats,
|
||||
"text", len, out_models);
|
||||
}
|
||||
XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
|
||||
int fnum,
|
||||
const char** fname,
|
||||
const char** ftype,
|
||||
int with_stats,
|
||||
const char *format,
|
||||
xgboost::bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
API_BEGIN();
|
||||
FeatureMap featmap;
|
||||
for (int i = 0; i < fnum; ++i) {
|
||||
featmap.PushBack(i, fname[i], ftype[i]);
|
||||
}
|
||||
XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models);
|
||||
XGBoostDumpModelImpl(handle, featmap, with_stats, format, len, out_models);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ namespace xgboost {
|
||||
|
||||
enum CLITask {
|
||||
kTrain = 0,
|
||||
kDump2Text = 1,
|
||||
kDumpModel = 1,
|
||||
kPredict = 2
|
||||
};
|
||||
|
||||
@ -62,6 +62,8 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
|
||||
bool pred_margin;
|
||||
/*! \brief whether dump statistics along with model */
|
||||
int dump_stats;
|
||||
/*! \brief what format to dump the model in */
|
||||
std::string dump_format;
|
||||
/*! \brief name of feature map */
|
||||
std::string name_fmap;
|
||||
/*! \brief name of dump file */
|
||||
@ -78,7 +80,7 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
|
||||
// NOTE: declare everything except eval_data_paths.
|
||||
DMLC_DECLARE_FIELD(task).set_default(kTrain)
|
||||
.add_enum("train", kTrain)
|
||||
.add_enum("dump", kDump2Text)
|
||||
.add_enum("dump", kDumpModel)
|
||||
.add_enum("pred", kPredict)
|
||||
.describe("Task to be performed by the CLI program.");
|
||||
DMLC_DECLARE_FIELD(silent).set_default(0).set_range(0, 2)
|
||||
@ -112,6 +114,8 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
|
||||
.describe("Whether to predict margin value instead of probability.");
|
||||
DMLC_DECLARE_FIELD(dump_stats).set_default(false)
|
||||
.describe("Whether dump the model statistics.");
|
||||
DMLC_DECLARE_FIELD(dump_format).set_default("text")
|
||||
.describe("What format to dump the model in.");
|
||||
DMLC_DECLARE_FIELD(name_fmap).set_default("NULL")
|
||||
.describe("Name of the feature map file.");
|
||||
DMLC_DECLARE_FIELD(name_dump).set_default("dump.txt")
|
||||
@ -259,7 +263,7 @@ void CLITrain(const CLIParam& param) {
|
||||
}
|
||||
}
|
||||
|
||||
void CLIDump2Text(const CLIParam& param) {
|
||||
void CLIDumpModel(const CLIParam& param) {
|
||||
FeatureMap fmap;
|
||||
if (param.name_fmap != "NULL") {
|
||||
std::unique_ptr<dmlc::Stream> fs(
|
||||
@ -276,14 +280,24 @@ void CLIDump2Text(const CLIParam& param) {
|
||||
learner->Configure(param.cfg);
|
||||
learner->Load(fi.get());
|
||||
// dump data
|
||||
std::vector<std::string> dump = learner->Dump2Text(fmap, param.dump_stats);
|
||||
std::vector<std::string> dump = learner->DumpModel(
|
||||
fmap, param.dump_stats, param.dump_format);
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(param.name_dump.c_str(), "w"));
|
||||
dmlc::ostream os(fo.get());
|
||||
if (param.dump_format == "json") {
|
||||
os << "[" << std::endl;
|
||||
for (size_t i = 0; i < dump.size(); ++i) {
|
||||
if (i != 0) os << "," << std::endl;
|
||||
os << dump[i]; // Dump the previously generated JSON here
|
||||
}
|
||||
os << std::endl << "]" << std::endl;
|
||||
} else {
|
||||
for (size_t i = 0; i < dump.size(); ++i) {
|
||||
os << "booster[" << i << "]:\n";
|
||||
os << dump[i];
|
||||
}
|
||||
}
|
||||
// force flush before fo destruct.
|
||||
os.set_stream(nullptr);
|
||||
}
|
||||
@ -347,7 +361,7 @@ int CLIRunTask(int argc, char *argv[]) {
|
||||
|
||||
switch (param.task) {
|
||||
case kTrain: CLITrain(param); break;
|
||||
case kDump2Text: CLIDump2Text(param); break;
|
||||
case kDumpModel: CLIDumpModel(param); break;
|
||||
case kPredict: CLIPredict(param); break;
|
||||
}
|
||||
rabit::Finalize();
|
||||
|
||||
@ -43,22 +43,22 @@ class SparsePageSource : public DataSource {
|
||||
/*!
|
||||
* \brief Create source by taking data from parser.
|
||||
* \param src source parser.
|
||||
* \param cache_prefix The cache_prefix of cache file location.
|
||||
* \param cache_info The cache_info of cache file location.
|
||||
*/
|
||||
static void Create(dmlc::Parser<uint32_t>* src,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_info);
|
||||
/*!
|
||||
* \brief Create source cache by copy content from DMatrix.
|
||||
* \param cache_prefix The cache_prefix of cache file location.
|
||||
* \param cache_info The cache_info of cache file location.
|
||||
*/
|
||||
static void Create(DMatrix* src,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_info);
|
||||
/*!
|
||||
* \brief Check if the cache file already exists.
|
||||
* \param cache_prefix The cache prefix of files.
|
||||
* \param cache_info The cache prefix of files.
|
||||
* \return Whether cache file already exists.
|
||||
*/
|
||||
static bool CacheExist(const std::string& cache_prefix);
|
||||
static bool CacheExist(const std::string& cache_info);
|
||||
/*! \brief page size 32 MB */
|
||||
static const size_t kPageSize = 32UL << 20UL;
|
||||
/*! \brief magic number used to identify Page */
|
||||
|
||||
@ -224,8 +224,26 @@ class GBLinear : public GradientBooster {
|
||||
LOG(FATAL) << "gblinear does not support predict leaf index";
|
||||
}
|
||||
|
||||
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const override {
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const override {
|
||||
std::stringstream fo("");
|
||||
if (format == "json") {
|
||||
fo << " { \"bias\": [" << std::endl;
|
||||
for (int i = 0; i < model.param.num_output_group; ++i) {
|
||||
if (i != 0) fo << "," << std::endl;
|
||||
fo << " " << model.bias()[i];
|
||||
}
|
||||
fo << std::endl << " ]," << std::endl
|
||||
<< " \"weight\": [" << std::endl;
|
||||
for (int i = 0; i < model.param.num_output_group; ++i) {
|
||||
for (unsigned j = 0; j < model.param.num_feature; ++j) {
|
||||
if (i != 0 || j != 0) fo << "," << std::endl;
|
||||
fo << " " << model[i][j];
|
||||
}
|
||||
}
|
||||
fo << std::endl << " ]" << std::endl << " }";
|
||||
} else {
|
||||
fo << "bias:\n";
|
||||
for (int i = 0; i < model.param.num_output_group; ++i) {
|
||||
fo << model.bias()[i] << std::endl;
|
||||
@ -236,6 +254,7 @@ class GBLinear : public GradientBooster {
|
||||
fo << model[i][j] << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<std::string> v;
|
||||
v.push_back(fo.str());
|
||||
return v;
|
||||
|
||||
@ -64,7 +64,7 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(DartTrainParam) {
|
||||
DMLC_DECLARE_FIELD(silent).set_default(false)
|
||||
.describe("Not print information during trainig.");
|
||||
.describe("Not print information during training.");
|
||||
DMLC_DECLARE_FIELD(sample_type).set_default(0)
|
||||
.add_enum("uniform", 0)
|
||||
.add_enum("weighted", 1)
|
||||
@ -275,10 +275,12 @@ class GBTree : public GradientBooster {
|
||||
this->PredPath(p_fmat, out_preds, ntree_limit);
|
||||
}
|
||||
|
||||
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const override {
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const override {
|
||||
std::vector<std::string> dump;
|
||||
for (size_t i = 0; i < trees.size(); i++) {
|
||||
dump.push_back(trees[i]->Dump2Text(fmap, option & 1));
|
||||
dump.push_back(trees[i]->DumpModel(fmap, with_stats, format));
|
||||
}
|
||||
return dump;
|
||||
}
|
||||
|
||||
@ -25,8 +25,10 @@ bool Learner::AllowLazyCheckPoint() const {
|
||||
}
|
||||
|
||||
std::vector<std::string>
|
||||
Learner::Dump2Text(const FeatureMap& fmap, int option) const {
|
||||
return gbm_->Dump2Text(fmap, option);
|
||||
Learner::DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const {
|
||||
return gbm_->DumpModel(fmap, with_stats, format);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -15,19 +15,33 @@ namespace tree {
|
||||
DMLC_REGISTER_PARAMETER(TrainParam);
|
||||
}
|
||||
// internal function to dump regression tree to text
|
||||
void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*)
|
||||
void DumpRegTree(std::stringstream& fo, // NOLINT(*)
|
||||
const RegTree& tree,
|
||||
const FeatureMap& fmap,
|
||||
int nid, int depth, bool with_stats) {
|
||||
for (int i = 0; i < depth; ++i) {
|
||||
fo << '\t';
|
||||
int nid, int depth, int add_comma,
|
||||
bool with_stats, std::string format) {
|
||||
if (format == "json") {
|
||||
if (add_comma) fo << ",";
|
||||
if (depth != 0) fo << std::endl;
|
||||
for (int i = 0; i < depth+1; ++i) fo << " ";
|
||||
} else {
|
||||
for (int i = 0; i < depth; ++i) fo << '\t';
|
||||
}
|
||||
if (tree[nid].is_leaf()) {
|
||||
if (format == "json") {
|
||||
fo << "{ \"nodeid\": " << nid
|
||||
<< ", \"leaf\": " << tree[nid].leaf_value();
|
||||
if (with_stats) {
|
||||
fo << ", \"cover\": " << tree.stat(nid).sum_hess;
|
||||
}
|
||||
fo << " }";
|
||||
} else {
|
||||
fo << nid << ":leaf=" << tree[nid].leaf_value();
|
||||
if (with_stats) {
|
||||
fo << ",cover=" << tree.stat(nid).sum_hess;
|
||||
}
|
||||
fo << '\n';
|
||||
}
|
||||
} else {
|
||||
// right then left,
|
||||
bst_float cond = tree[nid].split_cond();
|
||||
@ -37,47 +51,101 @@ void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*)
|
||||
case FeatureMap::kIndicator: {
|
||||
int nyes = tree[nid].default_left() ?
|
||||
tree[nid].cright() : tree[nid].cleft();
|
||||
if (format == "json") {
|
||||
fo << "{ \"nodeid\": " << nid
|
||||
<< ", \"depth\": " << depth
|
||||
<< ", \"split\": \"" << fmap.name(split_index) << "\""
|
||||
<< ", \"yes\": " << nyes
|
||||
<< ", \"no\": " << tree[nid].cdefault();
|
||||
} else {
|
||||
fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes
|
||||
<< ",no=" << tree[nid].cdefault();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case FeatureMap::kInteger: {
|
||||
if (format == "json") {
|
||||
fo << "{ \"nodeid\": " << nid
|
||||
<< ", \"depth\": " << depth
|
||||
<< ", \"split\": \"" << fmap.name(split_index) << "\""
|
||||
<< ", \"split_condition\": " << int(float(cond) + 1.0f)
|
||||
<< ", \"yes\": " << tree[nid].cleft()
|
||||
<< ", \"no\": " << tree[nid].cright()
|
||||
<< ", \"missing\": " << tree[nid].cdefault();
|
||||
} else {
|
||||
fo << nid << ":[" << fmap.name(split_index) << "<"
|
||||
<< int(float(cond)+1.0f)
|
||||
<< "] yes=" << tree[nid].cleft()
|
||||
<< ",no=" << tree[nid].cright()
|
||||
<< ",missing=" << tree[nid].cdefault();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case FeatureMap::kFloat:
|
||||
case FeatureMap::kQuantitive: {
|
||||
if (format == "json") {
|
||||
fo << "{ \"nodeid\": " << nid
|
||||
<< ", \"depth\": " << depth
|
||||
<< ", \"split\": \"" << fmap.name(split_index) << "\""
|
||||
<< ", \"split_condition\": " << float(cond)
|
||||
<< ", \"yes\": " << tree[nid].cleft()
|
||||
<< ", \"no\": " << tree[nid].cright()
|
||||
<< ", \"missing\": " << tree[nid].cdefault();
|
||||
} else {
|
||||
fo << nid << ":[" << fmap.name(split_index) << "<" << float(cond)
|
||||
<< "] yes=" << tree[nid].cleft()
|
||||
<< ",no=" << tree[nid].cright()
|
||||
<< ",missing=" << tree[nid].cdefault();
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: LOG(FATAL) << "unknown fmap type";
|
||||
}
|
||||
} else {
|
||||
if (format == "json") {
|
||||
fo << "{ \"nodeid\": " << nid
|
||||
<< ", \"depth\": " << depth
|
||||
<< ", \"split\": " << split_index
|
||||
<< ", \"split_condition\": " << float(cond)
|
||||
<< ", \"yes\": " << tree[nid].cleft()
|
||||
<< ", \"no\": " << tree[nid].cright()
|
||||
<< ", \"missing\": " << tree[nid].cdefault();
|
||||
} else {
|
||||
fo << nid << ":[f" << split_index << "<"<< float(cond)
|
||||
<< "] yes=" << tree[nid].cleft()
|
||||
<< ",no=" << tree[nid].cright()
|
||||
<< ",missing=" << tree[nid].cdefault();
|
||||
}
|
||||
}
|
||||
if (with_stats) {
|
||||
if (format == "json") {
|
||||
fo << ", \"gain\": " << tree.stat(nid).loss_chg
|
||||
<< ", \"cover\": " << tree.stat(nid).sum_hess;
|
||||
} else {
|
||||
fo << ",gain=" << tree.stat(nid).loss_chg << ",cover=" << tree.stat(nid).sum_hess;
|
||||
}
|
||||
}
|
||||
if (format == "json") {
|
||||
fo << ", \"children\": [";
|
||||
} else {
|
||||
fo << '\n';
|
||||
DumpRegTree2Text(fo, tree, fmap, tree[nid].cleft(), depth + 1, with_stats);
|
||||
DumpRegTree2Text(fo, tree, fmap, tree[nid].cright(), depth + 1, with_stats);
|
||||
}
|
||||
DumpRegTree(fo, tree, fmap, tree[nid].cleft(), depth + 1, false, with_stats, format);
|
||||
DumpRegTree(fo, tree, fmap, tree[nid].cright(), depth + 1, true, with_stats, format);
|
||||
if (format == "json") {
|
||||
fo << std::endl;
|
||||
for (int i = 0; i < depth+1; ++i) fo << " ";
|
||||
fo << "]}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string RegTree::Dump2Text(const FeatureMap& fmap, bool with_stats) const {
|
||||
std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const {
|
||||
std::stringstream fo("");
|
||||
for (int i = 0; i < param.num_roots; ++i) {
|
||||
DumpRegTree2Text(fo, *this, fmap, i, 0, with_stats);
|
||||
DumpRegTree(fo, *this, fmap, i, 0, false, with_stats, format);
|
||||
}
|
||||
return fo.str();
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import json
|
||||
|
||||
dpath = 'demo/data/'
|
||||
rng = np.random.RandomState(1994)
|
||||
@ -170,6 +171,39 @@ class TestBasic(unittest.TestCase):
|
||||
fscores = bst.get_fscore()
|
||||
assert scores1 == fscores
|
||||
|
||||
def test_dump(self):
|
||||
data = np.random.randn(100, 2)
|
||||
target = np.array([0, 1] * 50)
|
||||
features = ['Feature1', 'Feature2']
|
||||
|
||||
dm = xgb.DMatrix(data, label=target, feature_names=features)
|
||||
params = {'objective': 'binary:logistic',
|
||||
'eval_metric': 'logloss',
|
||||
'eta': 0.3,
|
||||
'max_depth': 1}
|
||||
|
||||
bst = xgb.train(params, dm, num_boost_round=1)
|
||||
|
||||
# number of feature importances should == number of features
|
||||
dump1 = bst.get_dump()
|
||||
self.assertEqual(len(dump1), 1, "Expected only 1 tree to be dumped.")
|
||||
self.assertEqual(len(dump1[0].splitlines()), 3,
|
||||
"Expected 1 root and 2 leaves - 3 lines in dump.")
|
||||
|
||||
dump2 = bst.get_dump(with_stats=True)
|
||||
self.assertEqual(dump2[0].count('\n'), 3,
|
||||
"Expected 1 root and 2 leaves - 3 lines in dump.")
|
||||
self.assertGreater(dump2[0].find('\n'), dump1[0].find('\n'),
|
||||
"Expected more info when with_stats=True is given.")
|
||||
|
||||
dump3 = bst.get_dump(dump_format="json")
|
||||
dump3j = json.loads(dump3[0])
|
||||
self.assertEqual(dump3j["nodeid"], 0, "Expected the root node on top.")
|
||||
|
||||
dump4 = bst.get_dump(dump_format="json", with_stats=True)
|
||||
dump4j = json.loads(dump4[0])
|
||||
self.assertIn("gain", dump4j, "Expected 'gain' to be dumped in JSON.")
|
||||
|
||||
def test_load_file_invalid(self):
|
||||
self.assertRaises(xgb.core.XGBoostError, xgb.Booster,
|
||||
model_file='incorrect_path')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user