[TREE] Refactor histmaker

This commit is contained in:
tqchen 2016-01-19 20:44:29 -08:00
parent 468bc7725a
commit 52227a8920
2 changed files with 38 additions and 21 deletions

View File

@ -206,6 +206,16 @@ class BaseMaker: public TreeUpdater {
const RegTree &tree) {
// set the positions in the nondefault
this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
this->SetDefaultPostion(p_fmat, tree);
}
/*!
* \brief helper function to set the non-leaf positions to default direction.
* This function can be applied multiple times and will get the same result.
* \param p_fmat feature matrix needed for tree construction
* \param tree the regression tree structure
*/
inline void SetDefaultPostion(DMatrix *p_fmat,
const RegTree &tree) {
// set rest of instances to default position
const RowSet &rowset = p_fmat->buffered_rowset();
// set default direct nodes to default
@ -222,7 +232,7 @@ class BaseMaker: public TreeUpdater {
if (tree[nid].cright() == -1) {
position[ridx] = ~nid;
}
} else {
} else {
// push to default branch
if (tree[nid].default_left()) {
this->SetEncodePosition(ridx, tree[nid].cleft());
@ -234,16 +244,16 @@ class BaseMaker: public TreeUpdater {
}
/*!
* \brief this is helper function uses column based data structure,
* update all positions into nondefault branch, if any, ignore the default branch
* \param nodes the set of nodes that contains the split to be used
* \param p_fmat feature matrix needed for tree construction
* \param tree the regression tree structure
* \param out_split_set The split index set
*/
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
DMatrix *p_fmat,
const RegTree &tree) {
inline void GetSplitSet(const std::vector<int> &nodes,
const RegTree &tree,
std::vector<unsigned>* out_split_set) {
std::vector<unsigned>& fsplits = *out_split_set;
fsplits.clear();
// step 1, classify the non-default data into right places
std::vector<unsigned> fsplits;
for (size_t i = 0; i < nodes.size(); ++i) {
const int nid = nodes[i];
if (!tree[nid].is_leaf()) {
@ -252,7 +262,19 @@ class BaseMaker: public TreeUpdater {
}
std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
}
/*!
* \brief this is helper function uses column based data structure,
* update all positions into nondefault branch, if any, ignore the default branch
* \param nodes the set of nodes that contains the split to be used
* \param p_fmat feature matrix needed for tree construction
* \param tree the regression tree structure
*/
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
DMatrix *p_fmat,
const RegTree &tree) {
std::vector<unsigned> fsplits;
this->GetSplitSet(nodes, tree, &fsplits);
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
while (iter->Next()) {
const ColBatch &batch = iter->Value();

View File

@ -355,8 +355,9 @@ class CQHistMaker: public HistMaker<TStats> {
#endif
}
void ResetPositionAfterSplit(DMatrix *p_fmat,
const RegTree &tree) override {
const RegTree &tree) override {
this->ResetPositionCol(this->qexpand, p_fmat, tree);
this->GetSplitSet(this->qexpand, tree, &fsplit_set);
}
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat,
@ -388,14 +389,10 @@ class CQHistMaker: public HistMaker<TStats> {
for (size_t i = 0; i < sketchs.size(); ++i) {
summary_array[i].Reserve(max_size);
}
// if it is C++11, use lazy evaluation for Allreduce
#if __cplusplus >= 201103L
auto lazy_get_summary = [&]()
#endif
{
{
// get smmary
thread_sketch.resize(this->get_nthread());
// number of rows in
// number of rows in data
const size_t nrows = p_fmat->buffered_rowset().size();
// start accumulating statistics
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set);
@ -422,15 +419,10 @@ class CQHistMaker: public HistMaker<TStats> {
summary_array[i].SetPrune(out, max_size);
}
CHECK_EQ(summary_array.size(), sketchs.size());
};
}
if (summary_array.size() != 0) {
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
#if __cplusplus >= 201103L
sreducer.Allreduce(dmlc::BeginPtr(summary_array),
nbytes, summary_array.size(), lazy_get_summary);
#else
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
#endif
}
// now we get the final result of sketch, setup the cut
this->wspace.cut.clear();
@ -617,6 +609,8 @@ class CQHistMaker: public HistMaker<TStats> {
std::vector<int> feat2workindex;
// set of index from fset that are real
std::vector<bst_uint> freal_set;
// set of index from that are split candidates.
std::vector<bst_uint> fsplit_set;
// thread temp data
std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch;
// used to hold statistics
@ -633,6 +627,7 @@ class CQHistMaker: public HistMaker<TStats> {
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
};
template<typename TStats>
class QuantileHistMaker: public HistMaker<TStats> {
protected: