Add dump_format=json option (#1726)

* Add format to the params accepted by DumpModel

Currently, only the test format is supported when trying to dump
a model. The plan is to add more such formats like JSON which are
easy to read and/or parse by machines. And to make the interface
for this even more generic to allow other formats to be added.

Hence, we make some modifications to make these function generic
and accept a new parameter "format" which signifies the format of
the dump to be created.

* Fix typos and errors in docs

* plugin: Mention all the register macros available

Document the register macros currently available to the plugin
writers so they know what exactly can be extended using hooks.

* sparce_page_source: Use same arg name in .h and .cc

* gbm: Add JSON dump

The dump_format argument can be used to specify what type
of dump file should be created. Add functionality to dump
gblinear and gbtree into a JSON file.

The JSON file has an array, each item is a JSON object for the tree.
For gblinear:
 - The item is the bias and weights vectors
For gbtree:
 - The item is the root node. The root node has a attribute "children"
   which holds the children nodes. This happens recursively.

* core.py: Add arg dump_format for get_dump()
This commit is contained in:
AbdealiJK
2016-11-04 22:25:25 +05:30
committed by Tianqi Chen
parent 9c693f0f5f
commit b94fcab4dc
16 changed files with 320 additions and 92 deletions

View File

@@ -662,13 +662,14 @@ inline void XGBoostDumpModelImpl(
BoosterHandle handle,
const FeatureMap& fmap,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
Booster *bst = static_cast<Booster*>(handle);
bst->LazyInit();
str_vecs = bst->learner()->Dump2Text(fmap, with_stats != 0);
str_vecs = bst->learner()->DumpModel(fmap, with_stats != 0, format);
charp_vecs.resize(str_vecs.size());
for (size_t i = 0; i < str_vecs.size(); ++i) {
charp_vecs[i] = str_vecs[i].c_str();
@@ -681,6 +682,14 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models);
}
XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
const char* fmap,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
API_BEGIN();
FeatureMap featmap;
if (strlen(fmap) != 0) {
@@ -689,7 +698,7 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
dmlc::istream is(fs.get());
featmap.LoadText(is);
}
XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models);
XGBoostDumpModelImpl(handle, featmap, with_stats, format, len, out_models);
API_END();
}
@@ -700,12 +709,23 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int with_stats,
xgboost::bst_ulong* len,
const char*** out_models) {
XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats,
"text", len, out_models);
}
XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
int fnum,
const char** fname,
const char** ftype,
int with_stats,
const char *format,
xgboost::bst_ulong* len,
const char*** out_models) {
API_BEGIN();
FeatureMap featmap;
for (int i = 0; i < fnum; ++i) {
featmap.PushBack(i, fname[i], ftype[i]);
}
XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models);
XGBoostDumpModelImpl(handle, featmap, with_stats, format, len, out_models);
API_END();
}

View File

@@ -27,7 +27,7 @@ namespace xgboost {
enum CLITask {
kTrain = 0,
kDump2Text = 1,
kDumpModel = 1,
kPredict = 2
};
@@ -62,6 +62,8 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
bool pred_margin;
/*! \brief whether dump statistics along with model */
int dump_stats;
/*! \brief what format to dump the model in */
std::string dump_format;
/*! \brief name of feature map */
std::string name_fmap;
/*! \brief name of dump file */
@@ -78,7 +80,7 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
// NOTE: declare everything except eval_data_paths.
DMLC_DECLARE_FIELD(task).set_default(kTrain)
.add_enum("train", kTrain)
.add_enum("dump", kDump2Text)
.add_enum("dump", kDumpModel)
.add_enum("pred", kPredict)
.describe("Task to be performed by the CLI program.");
DMLC_DECLARE_FIELD(silent).set_default(0).set_range(0, 2)
@@ -112,6 +114,8 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
.describe("Whether to predict margin value instead of probability.");
DMLC_DECLARE_FIELD(dump_stats).set_default(false)
.describe("Whether dump the model statistics.");
DMLC_DECLARE_FIELD(dump_format).set_default("text")
.describe("What format to dump the model in.");
DMLC_DECLARE_FIELD(name_fmap).set_default("NULL")
.describe("Name of the feature map file.");
DMLC_DECLARE_FIELD(name_dump).set_default("dump.txt")
@@ -259,7 +263,7 @@ void CLITrain(const CLIParam& param) {
}
}
void CLIDump2Text(const CLIParam& param) {
void CLIDumpModel(const CLIParam& param) {
FeatureMap fmap;
if (param.name_fmap != "NULL") {
std::unique_ptr<dmlc::Stream> fs(
@@ -276,13 +280,23 @@ void CLIDump2Text(const CLIParam& param) {
learner->Configure(param.cfg);
learner->Load(fi.get());
// dump data
std::vector<std::string> dump = learner->Dump2Text(fmap, param.dump_stats);
std::vector<std::string> dump = learner->DumpModel(
fmap, param.dump_stats, param.dump_format);
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(param.name_dump.c_str(), "w"));
dmlc::ostream os(fo.get());
for (size_t i = 0; i < dump.size(); ++i) {
os << "booster[" << i << "]:\n";
os << dump[i];
if (param.dump_format == "json") {
os << "[" << std::endl;
for (size_t i = 0; i < dump.size(); ++i) {
if (i != 0) os << "," << std::endl;
os << dump[i]; // Dump the previously generated JSON here
}
os << std::endl << "]" << std::endl;
} else {
for (size_t i = 0; i < dump.size(); ++i) {
os << "booster[" << i << "]:\n";
os << dump[i];
}
}
// force flush before fo destruct.
os.set_stream(nullptr);
@@ -347,7 +361,7 @@ int CLIRunTask(int argc, char *argv[]) {
switch (param.task) {
case kTrain: CLITrain(param); break;
case kDump2Text: CLIDump2Text(param); break;
case kDumpModel: CLIDumpModel(param); break;
case kPredict: CLIPredict(param); break;
}
rabit::Finalize();

View File

@@ -43,22 +43,22 @@ class SparsePageSource : public DataSource {
/*!
* \brief Create source by taking data from parser.
* \param src source parser.
* \param cache_prefix The cache_prefix of cache file location.
* \param cache_info The cache_info of cache file location.
*/
static void Create(dmlc::Parser<uint32_t>* src,
const std::string& cache_prefix);
const std::string& cache_info);
/*!
* \brief Create source cache by copy content from DMatrix.
* \param cache_prefix The cache_prefix of cache file location.
* \param cache_info The cache_info of cache file location.
*/
static void Create(DMatrix* src,
const std::string& cache_prefix);
const std::string& cache_info);
/*!
* \brief Check if the cache file already exists.
* \param cache_prefix The cache prefix of files.
* \param cache_info The cache prefix of files.
* \return Whether cache file already exists.
*/
static bool CacheExist(const std::string& cache_prefix);
static bool CacheExist(const std::string& cache_info);
/*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL;
/*! \brief magic number used to identify Page */

View File

@@ -224,16 +224,35 @@ class GBLinear : public GradientBooster {
LOG(FATAL) << "gblinear does not support predict leaf index";
}
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const override {
std::vector<std::string> DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const override {
std::stringstream fo("");
fo << "bias:\n";
for (int i = 0; i < model.param.num_output_group; ++i) {
fo << model.bias()[i] << std::endl;
}
fo << "weight:\n";
for (int i = 0; i < model.param.num_output_group; ++i) {
for (unsigned j = 0; j <model.param.num_feature; ++j) {
fo << model[i][j] << std::endl;
if (format == "json") {
fo << " { \"bias\": [" << std::endl;
for (int i = 0; i < model.param.num_output_group; ++i) {
if (i != 0) fo << "," << std::endl;
fo << " " << model.bias()[i];
}
fo << std::endl << " ]," << std::endl
<< " \"weight\": [" << std::endl;
for (int i = 0; i < model.param.num_output_group; ++i) {
for (unsigned j = 0; j < model.param.num_feature; ++j) {
if (i != 0 || j != 0) fo << "," << std::endl;
fo << " " << model[i][j];
}
}
fo << std::endl << " ]" << std::endl << " }";
} else {
fo << "bias:\n";
for (int i = 0; i < model.param.num_output_group; ++i) {
fo << model.bias()[i] << std::endl;
}
fo << "weight:\n";
for (int i = 0; i < model.param.num_output_group; ++i) {
for (unsigned j = 0; j <model.param.num_feature; ++j) {
fo << model[i][j] << std::endl;
}
}
}
std::vector<std::string> v;

View File

@@ -64,7 +64,7 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
// declare parameters
DMLC_DECLARE_PARAMETER(DartTrainParam) {
DMLC_DECLARE_FIELD(silent).set_default(false)
.describe("Not print information during trainig.");
.describe("Not print information during training.");
DMLC_DECLARE_FIELD(sample_type).set_default(0)
.add_enum("uniform", 0)
.add_enum("weighted", 1)
@@ -275,10 +275,12 @@ class GBTree : public GradientBooster {
this->PredPath(p_fmat, out_preds, ntree_limit);
}
std::vector<std::string> Dump2Text(const FeatureMap& fmap, int option) const override {
std::vector<std::string> DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const override {
std::vector<std::string> dump;
for (size_t i = 0; i < trees.size(); i++) {
dump.push_back(trees[i]->Dump2Text(fmap, option & 1));
dump.push_back(trees[i]->DumpModel(fmap, with_stats, format));
}
return dump;
}

View File

@@ -25,8 +25,10 @@ bool Learner::AllowLazyCheckPoint() const {
}
std::vector<std::string>
Learner::Dump2Text(const FeatureMap& fmap, int option) const {
return gbm_->Dump2Text(fmap, option);
Learner::DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const {
return gbm_->DumpModel(fmap, with_stats, format);
}

View File

@@ -15,19 +15,33 @@ namespace tree {
DMLC_REGISTER_PARAMETER(TrainParam);
}
// internal function to dump regression tree to text
void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*)
const RegTree& tree,
const FeatureMap& fmap,
int nid, int depth, bool with_stats) {
for (int i = 0; i < depth; ++i) {
fo << '\t';
void DumpRegTree(std::stringstream& fo, // NOLINT(*)
const RegTree& tree,
const FeatureMap& fmap,
int nid, int depth, int add_comma,
bool with_stats, std::string format) {
if (format == "json") {
if (add_comma) fo << ",";
if (depth != 0) fo << std::endl;
for (int i = 0; i < depth+1; ++i) fo << " ";
} else {
for (int i = 0; i < depth; ++i) fo << '\t';
}
if (tree[nid].is_leaf()) {
fo << nid << ":leaf=" << tree[nid].leaf_value();
if (with_stats) {
fo << ",cover=" << tree.stat(nid).sum_hess;
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"leaf\": " << tree[nid].leaf_value();
if (with_stats) {
fo << ", \"cover\": " << tree.stat(nid).sum_hess;
}
fo << " }";
} else {
fo << nid << ":leaf=" << tree[nid].leaf_value();
if (with_stats) {
fo << ",cover=" << tree.stat(nid).sum_hess;
}
fo << '\n';
}
fo << '\n';
} else {
// right then left,
bst_float cond = tree[nid].split_cond();
@@ -37,47 +51,101 @@ void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*)
case FeatureMap::kIndicator: {
int nyes = tree[nid].default_left() ?
tree[nid].cright() : tree[nid].cleft();
fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes
<< ",no=" << tree[nid].cdefault();
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.name(split_index) << "\""
<< ", \"yes\": " << nyes
<< ", \"no\": " << tree[nid].cdefault();
} else {
fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes
<< ",no=" << tree[nid].cdefault();
}
break;
}
case FeatureMap::kInteger: {
fo << nid << ":[" << fmap.name(split_index) << "<"
<< int(float(cond)+1.0f)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.name(split_index) << "\""
<< ", \"split_condition\": " << int(float(cond) + 1.0f)
<< ", \"yes\": " << tree[nid].cleft()
<< ", \"no\": " << tree[nid].cright()
<< ", \"missing\": " << tree[nid].cdefault();
} else {
fo << nid << ":[" << fmap.name(split_index) << "<"
<< int(float(cond)+1.0f)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
}
break;
}
case FeatureMap::kFloat:
case FeatureMap::kQuantitive: {
fo << nid << ":[" << fmap.name(split_index) << "<"<< float(cond)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
break;
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": \"" << fmap.name(split_index) << "\""
<< ", \"split_condition\": " << float(cond)
<< ", \"yes\": " << tree[nid].cleft()
<< ", \"no\": " << tree[nid].cright()
<< ", \"missing\": " << tree[nid].cdefault();
} else {
fo << nid << ":[" << fmap.name(split_index) << "<" << float(cond)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
}
break;
}
default: LOG(FATAL) << "unknown fmap type";
}
} else {
fo << nid << ":[f" << split_index << "<"<< float(cond)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
if (format == "json") {
fo << "{ \"nodeid\": " << nid
<< ", \"depth\": " << depth
<< ", \"split\": " << split_index
<< ", \"split_condition\": " << float(cond)
<< ", \"yes\": " << tree[nid].cleft()
<< ", \"no\": " << tree[nid].cright()
<< ", \"missing\": " << tree[nid].cdefault();
} else {
fo << nid << ":[f" << split_index << "<"<< float(cond)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
}
}
if (with_stats) {
fo << ",gain=" << tree.stat(nid).loss_chg << ",cover=" << tree.stat(nid).sum_hess;
if (format == "json") {
fo << ", \"gain\": " << tree.stat(nid).loss_chg
<< ", \"cover\": " << tree.stat(nid).sum_hess;
} else {
fo << ",gain=" << tree.stat(nid).loss_chg << ",cover=" << tree.stat(nid).sum_hess;
}
}
if (format == "json") {
fo << ", \"children\": [";
} else {
fo << '\n';
}
DumpRegTree(fo, tree, fmap, tree[nid].cleft(), depth + 1, false, with_stats, format);
DumpRegTree(fo, tree, fmap, tree[nid].cright(), depth + 1, true, with_stats, format);
if (format == "json") {
fo << std::endl;
for (int i = 0; i < depth+1; ++i) fo << " ";
fo << "]}";
}
fo << '\n';
DumpRegTree2Text(fo, tree, fmap, tree[nid].cleft(), depth + 1, with_stats);
DumpRegTree2Text(fo, tree, fmap, tree[nid].cright(), depth + 1, with_stats);
}
}
std::string RegTree::Dump2Text(const FeatureMap& fmap, bool with_stats) const {
std::string RegTree::DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const {
std::stringstream fo("");
for (int i = 0; i < param.num_roots; ++i) {
DumpRegTree2Text(fo, *this, fmap, i, 0, with_stats);
DumpRegTree(fo, *this, fmap, i, 0, false, with_stats, format);
}
return fo.str();
}