[TREE] Finalize regression tree refactor

This commit is contained in:
tqchen
2016-01-01 02:54:28 -08:00
parent 844e8a153d
commit a62a66d545
4 changed files with 227 additions and 113 deletions

79
src/tree/tree_model.cc Normal file
View File

@@ -0,0 +1,79 @@
/*!
* Copyright 2015 by Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
#include <xgboost/tree_model.h>
#include <sstream>
namespace xgboost {
// internal function to dump regression tree to text
void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*)
const RegTree& tree,
const FeatureMap& fmap,
int nid, int depth, bool with_stats) {
for (int i = 0; i < depth; ++i) {
fo << '\t';
}
if (tree[nid].is_leaf()) {
fo << nid << ":leaf=" << tree[nid].leaf_value();
if (with_stats) {
fo << ",cover=" << tree.stat(nid).sum_hess;
}
fo << '\n';
} else {
// right then left,
bst_float cond = tree[nid].split_cond();
const unsigned split_index = tree[nid].split_index();
if (split_index < fmap.size()) {
switch (fmap.type(split_index)) {
case FeatureMap::kIndicator: {
int nyes = tree[nid].default_left() ?
tree[nid].cright() : tree[nid].cleft();
fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes
<< ",no=" << tree[nid].cdefault();
break;
}
case FeatureMap::kInteger: {
fo << nid << ":[" << fmap.name(split_index) << "<"
<< int(float(cond)+1.0f)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
break;
}
case FeatureMap::kFloat:
case FeatureMap::kQuantitive: {
fo << nid << ":[" << fmap.name(split_index) << "<"<< float(cond)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
break;
}
default: LOG(FATAL) << "unknown fmap type";
}
} else {
fo << nid << ":[f" << split_index << "<"<< float(cond)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
}
if (with_stats) {
fo << ",gain=" << tree.stat(nid).loss_chg << ",cover=" << tree.stat(nid).sum_hess;
}
fo << '\n';
DumpRegTree2Text(fo, tree, fmap, tree[nid].cleft(), depth + 1, with_stats);
DumpRegTree2Text(fo, tree, fmap, tree[nid].cright(), depth + 1, with_stats);
}
}
std::string RegTree::Dump2Text(const FeatureMap& fmap, bool with_stats) const {
std::stringstream fo("");
for (int i = 0; i < param.num_roots; ++i) {
DumpRegTree2Text(fo, *this, fmap, i, 0, with_stats);
}
return fo.str();
}
} // namespace xgboost