[TREE] Finalize regression tree refactor

This commit is contained in:
tqchen 2016-01-01 02:54:28 -08:00
parent 844e8a153d
commit a62a66d545
4 changed files with 227 additions and 113 deletions

View File

@ -0,0 +1,92 @@
/*!
* Copyright 2014 by Contributors
* \file feature_map.h
* \brief Feature map data structure to help visualization and model dump.
* \author Tianqi Chen
*/
#ifndef XGBOOST_FEATURE_MAP_H_
#define XGBOOST_FEATURE_MAP_H_
#include <vector>
#include <string>
#include <cstring>
#include <iostream>
namespace xgboost {
/*!
* \brief Feature map data structure to help text model dump.
* TODO(tqchen) consider making it even more lightweight.
*/
class FeatureMap {
 public:
  /*! \brief type of feature maps */
  enum Type {
    kIndicator = 0,
    kQuantitive = 1,
    kInteger = 2,
    kFloat = 3
  };
  /*!
   * \brief load feature map from input stream
   *  Records are read as three whitespace separated tokens
   *  (index, name, type) until the stream ends or a record fails to parse.
   *  Every record is appended through PushBack.
   * \param is Input text stream
   */
  inline void LoadText(std::istream& is) {  // NOLINT(*)
    int fid;
    std::string fname, ftype;
    while (is >> fid >> fname >> ftype) {
      this->PushBack(fid, fname.c_str(), ftype.c_str());
    }
  }
  /*!
   * \brief push back feature map.
   *  Entries must be added in index order: fid has to equal the
   *  number of features already stored.
   * \param fid The feature index.
   * \param fname The feature name.
   * \param ftype The feature type, one of "i", "q", "int", "float".
   */
  inline void PushBack(int fid, const char *fname, const char *ftype) {
    CHECK_EQ(fid, static_cast<int>(names_.size()));
    // Translate the type before touching either vector: if the type string
    // is rejected and the failure propagates to the caller, names_ and
    // types_ must not be left with different lengths.
    const Type parsed_type = GetType(ftype);
    names_.push_back(std::string(fname));
    types_.push_back(parsed_type);
  }
  /*! \brief clear the feature map */
  inline void Clear() {
    names_.clear();
    types_.clear();
  }
  /*! \return number of known features */
  inline size_t size() const {
    return names_.size();
  }
  /*!
   * \param idx The feature index, must be smaller than size().
   * \return name of specific feature, the pointer refers to storage owned by this map
   */
  inline const char* name(size_t idx) const {
    CHECK_LT(idx, names_.size()) << "FeatureMap feature index exceed bound";
    return names_[idx].c_str();
  }
  /*!
   * \param idx The feature index, must be smaller than size().
   * \return type of specific feature
   */
  Type type(size_t idx) const {
    // bound-check against the vector that is actually indexed
    CHECK_LT(idx, types_.size()) << "FeatureMap feature index exceed bound";
    return types_[idx];
  }

 private:
  /*!
   * \brief translate a type name into the feature type enum.
   * \param tname The type name, one of "i", "q", "int", "float".
   * \return The translated type.
   */
  inline static Type GetType(const char* tname) {
    if (!std::strcmp("i", tname)) return kIndicator;
    if (!std::strcmp("q", tname)) return kQuantitive;
    if (!std::strcmp("int", tname)) return kInteger;
    if (!std::strcmp("float", tname)) return kFloat;
    LOG(FATAL) << "unknown feature type, use i for indicator, q for quantity, "
               << "int for integer and float for floating point";
    // only reached if the fatal log statement returns
    return kIndicator;
  }
  /*! \brief name of the feature */
  std::vector<std::string> names_;
  /*! \brief type of the feature */
  std::vector<Type> types_;
};
} // namespace xgboost
#endif // XGBOOST_FEATURE_MAP_H_

View File

@ -12,11 +12,55 @@
#include <dmlc/parameter.h>
#include <limits>
#include <vector>
#include <string>
#include <cstring>
#include <algorithm>
#include "./base.h"
#include "./data.h"
#include "./feature_map.h"
namespace xgboost {
/*!
 * \brief meta parameters of the tree
 *  The layout is asserted to be exactly 37 ints, see the constructor.
 */
struct TreeParam : public dmlc::Parameter<TreeParam> {
  /*! \brief number of start root */
  int num_roots;
  /*! \brief total number of nodes */
  int num_nodes;
  /*! \brief number of deleted nodes */
  int num_deleted;
  /*! \brief maximum depth, this is a statistics of the tree */
  int max_depth;
  /*! \brief number of features used for tree construction */
  int num_feature;
  /*!
   * \brief leaf vector size, used for vector tree
   *  used to store more than one dimensional information in tree
   */
  int size_leaf_vector;
  /*! \brief reserved part, make sure alignment works for 64bit */
  int reserved[31];
  /*!
   * \brief constructor
   *  Zero-fills every field, then sets both num_roots and num_nodes to 1.
   */
  TreeParam() {
    // the structure must stay compact: 6 named fields + 31 reserved ints
    static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
                  "TreeParam: 64 bit align");
    std::memset(this, 0, sizeof(*this));
    num_roots = 1;
    num_nodes = 1;
  }
  // declare the parameters
  DMLC_DECLARE_PARAMETER(TreeParam) {
    // Only the fields a user may set are declared here;
    // the remaining fields are filled in by the algorithm.
    DMLC_DECLARE_FIELD(num_roots)
        .set_lower_bound(1)
        .set_default(1)
        .describe("Number of start root of trees.");
    DMLC_DECLARE_FIELD(num_feature)
        .describe("Number of features used in tree construction.");
    DMLC_DECLARE_FIELD(size_leaf_vector)
        .set_lower_bound(0)
        .set_default(0)
        .describe("Size of leaf vector, reserved for vector tree");
  }
};
/*!
* \brief template class of TreeModel
* \tparam TSplitCond data type to indicate split condition
@ -29,34 +73,6 @@ class TreeModel {
typedef TNodeStat NodeStat;
/*! \brief auxiliary statistics of node to help tree building */
typedef TSplitCond SplitCond;
/*!
 * \brief parameters of the tree
 *  The layout is asserted to be exactly 37 ints, see the constructor.
 */
struct TreeParam {
  /*! \brief number of start root */
  int num_roots;
  /*! \brief total number of nodes */
  int num_nodes;
  /*! \brief number of deleted nodes */
  int num_deleted;
  /*! \brief maximum depth, this is a statistics of the tree */
  int max_depth;
  /*! \brief number of features used for tree construction */
  int num_feature;
  /*!
   * \brief leaf vector size, used for vector tree
   *  used to store more than one dimensional information in tree
   */
  int size_leaf_vector;
  /*! \brief reserved part */
  int reserved[31];
  /*!
   * \brief constructor
   *  Zero-fills every field, then sets both num_roots and num_nodes to 1.
   */
  TreeParam() {
    // the structure must stay compact: 6 named fields + 31 reserved ints
    static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
                  "TreeParam: 64 bit align");
    std::memset(this, 0, sizeof(*this));
    num_roots = 1;
    num_nodes = 1;
  }
};
/*! \brief tree node */
class Node {
public:
@ -259,6 +275,10 @@ class TreeModel {
inline NodeStat& stat(int nid) {
// mutable overload: nid is used directly as the index into stats,
// no range check is done in this accessor
return stats[nid];
}
/*!
 * \brief get node statistics given nid (read-only overload)
 * \param nid node id, used directly as the index into stats
 * \return const reference to the statistics entry of the node
 */
inline const NodeStat& stat(int nid) const {
return stats[nid];
}
/*! \brief get leaf vector given nid */
inline bst_float* leafvec(int nid) {
if (leaf_vector.size() == 0) return nullptr;
@ -444,7 +464,7 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
* \param root_id starting root index of the instance
* \return the leaf index of the given feature
*/
inline int GetLeafIndex(const FVec& feat, unsigned root_id = 0) const;
inline int GetLeafIndex(const FVec& feat, unsigned root_id = 0) const;
/*!
* \brief get the prediction of regression tree, only accepts dense feature vector
* \param feats dense feature vector, if the feature is missing the field is set to NaN
@ -459,6 +479,13 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
* \param is_unknown Whether current required feature is missing.
*/
inline int GetNext(int pid, float fvalue, bool is_unknown) const;
/*!
* \brief dump model to text string
* \param fmap feature map of feature types
* \param with_stats whether dump out statistics as well
* \return the string of dumped model
*/
std::string Dump2Text(const FeatureMap& fmap, bool with_stats) const;
};
// implementations of inline functions
@ -518,6 +545,5 @@ inline int RegTree::GetNext(int pid, float fvalue, bool is_unknown) const {
}
}
}
} // namespace xgboost
#endif // XGBOOST_TREE_MODEL_H_

View File

@ -1,83 +0,0 @@
/*!
* Copyright 2014 by Contributors
* \file fmap.h
* \brief helper class that holds the feature names and interpretations
* \author Tianqi Chen
*/
#ifndef XGBOOST_UTILS_FMAP_H_
#define XGBOOST_UTILS_FMAP_H_
#include <vector>
#include <string>
#include <cstring>
#include "./utils.h"
namespace xgboost {
namespace utils {
/*! \brief helper class that holds the feature names and interpretations */
class FeatMap {
public:
enum Type {
kIndicator = 0,
kQuantitive = 1,
kInteger = 2,
kFloat = 3
};
// function definitions
/*! \brief load feature map from text format */
inline void LoadText(const char *fname) {
std::FILE *fi = utils::FopenCheck(fname, "r");
this->LoadText(fi);
std::fclose(fi);
}
/*! \brief load feature map from text format */
inline void LoadText(std::FILE *fi) {
int fid;
char fname[1256], ftype[1256];
while (std::fscanf(fi, "%d\t%[^\t]\t%s\n", &fid, fname, ftype) == 3) {
this->PushBack(fid, fname, ftype);
}
}
/*!\brief push back feature map */
inline void PushBack(int fid, const char *fname, const char *ftype) {
utils::Check(fid == static_cast<int>(names_.size()), "invalid fmap format");
names_.push_back(std::string(fname));
types_.push_back(GetType(ftype));
}
inline void Clear(void) {
names_.clear(); types_.clear();
}
/*! \brief number of known features */
size_t size(void) const {
return names_.size();
}
/*! \brief return name of specific feature */
const char* name(size_t idx) const {
utils::Assert(idx < names_.size(), "utils::FMap::name feature index exceed bound");
return names_[idx].c_str();
}
/*! \brief return type of specific feature */
const Type& type(size_t idx) const {
utils::Assert(idx < names_.size(), "utils::FMap::type feature index exceed bound");
return types_[idx];
}
private:
inline static Type GetType(const char *tname) {
using namespace std;
if (!strcmp("i", tname)) return kIndicator;
if (!strcmp("q", tname)) return kQuantitive;
if (!strcmp("int", tname)) return kInteger;
if (!strcmp("float", tname)) return kFloat;
utils::Error("unknown feature type, use i for indicator and q for quantity");
return kIndicator;
}
/*! \brief name of the feature */
std::vector<std::string> names_;
/*! \brief type of the feature */
std::vector<Type> types_;
};
} // namespace utils
} // namespace xgboost
#endif // XGBOOST_UTILS_FMAP_H_

79
src/tree/tree_model.cc Normal file
View File

@ -0,0 +1,79 @@
/*!
* Copyright 2015 by Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
#include <xgboost/tree_model.h>
#include <sstream>
namespace xgboost {
// internal function to dump regression tree to text
void DumpRegTree2Text(std::stringstream& fo, // NOLINT(*)
const RegTree& tree,
const FeatureMap& fmap,
int nid, int depth, bool with_stats) {
for (int i = 0; i < depth; ++i) {
fo << '\t';
}
if (tree[nid].is_leaf()) {
fo << nid << ":leaf=" << tree[nid].leaf_value();
if (with_stats) {
fo << ",cover=" << tree.stat(nid).sum_hess;
}
fo << '\n';
} else {
// right then left,
bst_float cond = tree[nid].split_cond();
const unsigned split_index = tree[nid].split_index();
if (split_index < fmap.size()) {
switch (fmap.type(split_index)) {
case FeatureMap::kIndicator: {
int nyes = tree[nid].default_left() ?
tree[nid].cright() : tree[nid].cleft();
fo << nid << ":[" << fmap.name(split_index) << "] yes=" << nyes
<< ",no=" << tree[nid].cdefault();
break;
}
case FeatureMap::kInteger: {
fo << nid << ":[" << fmap.name(split_index) << "<"
<< int(float(cond)+1.0f)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
break;
}
case FeatureMap::kFloat:
case FeatureMap::kQuantitive: {
fo << nid << ":[" << fmap.name(split_index) << "<"<< float(cond)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
break;
}
default: LOG(FATAL) << "unknown fmap type";
}
} else {
fo << nid << ":[f" << split_index << "<"<< float(cond)
<< "] yes=" << tree[nid].cleft()
<< ",no=" << tree[nid].cright()
<< ",missing=" << tree[nid].cdefault();
}
if (with_stats) {
fo << ",gain=" << tree.stat(nid).loss_chg << ",cover=" << tree.stat(nid).sum_hess;
}
fo << '\n';
DumpRegTree2Text(fo, tree, fmap, tree[nid].cleft(), depth + 1, with_stats);
DumpRegTree2Text(fo, tree, fmap, tree[nid].cright(), depth + 1, with_stats);
}
}
/*!
 * \brief dump model to text string
 *  Every root in [0, param.num_roots) is dumped in order, starting at depth 0,
 *  and the pieces are concatenated.
 * \param fmap feature map of feature types
 * \param with_stats whether dump out statistics as well
 * \return the string of dumped model
 */
std::string RegTree::Dump2Text(const FeatureMap& fmap, bool with_stats) const {
  std::stringstream text;
  const int num_roots = param.num_roots;
  for (int root = 0; root < num_roots; ++root) {
    DumpRegTree2Text(text, *this, fmap, root, 0, with_stats);
  }
  return text.str();
}
} // namespace xgboost