[TREE] Refactor histmaker

This commit is contained in:
tqchen 2016-01-19 20:44:29 -08:00
parent 468bc7725a
commit 52227a8920
2 changed files with 38 additions and 21 deletions

View File

@ -206,6 +206,16 @@ class BaseMaker: public TreeUpdater {
const RegTree &tree) { const RegTree &tree) {
// set the positions in the nondefault // set the positions in the nondefault
this->SetNonDefaultPositionCol(nodes, p_fmat, tree); this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
this->SetDefaultPostion(p_fmat, tree);
}
/*!
* \brief helper function to set the non-leaf positions to default direction.
* This function can be applied multiple times and will get the same result.
* \param p_fmat feature matrix needed for tree construction
* \param tree the regression tree structure
*/
inline void SetDefaultPostion(DMatrix *p_fmat,
const RegTree &tree) {
// set rest of instances to default position // set rest of instances to default position
const RowSet &rowset = p_fmat->buffered_rowset(); const RowSet &rowset = p_fmat->buffered_rowset();
// set default direct nodes to default // set default direct nodes to default
@ -222,7 +232,7 @@ class BaseMaker: public TreeUpdater {
if (tree[nid].cright() == -1) { if (tree[nid].cright() == -1) {
position[ridx] = ~nid; position[ridx] = ~nid;
} }
} else { } else {
// push to default branch // push to default branch
if (tree[nid].default_left()) { if (tree[nid].default_left()) {
this->SetEncodePosition(ridx, tree[nid].cleft()); this->SetEncodePosition(ridx, tree[nid].cleft());
@ -234,16 +244,16 @@ class BaseMaker: public TreeUpdater {
} }
/*! /*!
* \brief this is helper function uses column based data structure, * \brief this is helper function uses column based data structure,
* update all positions into nondefault branch, if any, ignore the default branch
* \param nodes the set of nodes that contains the split to be used * \param nodes the set of nodes that contains the split to be used
* \param p_fmat feature matrix needed for tree construction
* \param tree the regression tree structure * \param tree the regression tree structure
* \param out_split_set The split index set
*/ */
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes, inline void GetSplitSet(const std::vector<int> &nodes,
DMatrix *p_fmat, const RegTree &tree,
const RegTree &tree) { std::vector<unsigned>* out_split_set) {
std::vector<unsigned>& fsplits = *out_split_set;
fsplits.clear();
// step 1, classify the non-default data into right places // step 1, classify the non-default data into right places
std::vector<unsigned> fsplits;
for (size_t i = 0; i < nodes.size(); ++i) { for (size_t i = 0; i < nodes.size(); ++i) {
const int nid = nodes[i]; const int nid = nodes[i];
if (!tree[nid].is_leaf()) { if (!tree[nid].is_leaf()) {
@ -252,7 +262,19 @@ class BaseMaker: public TreeUpdater {
} }
std::sort(fsplits.begin(), fsplits.end()); std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
}
/*!
* \brief this is helper function uses column based data structure,
* update all positions into nondefault branch, if any, ignore the default branch
* \param nodes the set of nodes that contains the split to be used
* \param p_fmat feature matrix needed for tree construction
* \param tree the regression tree structure
*/
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
DMatrix *p_fmat,
const RegTree &tree) {
std::vector<unsigned> fsplits;
this->GetSplitSet(nodes, tree, &fsplits);
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits); dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
while (iter->Next()) { while (iter->Next()) {
const ColBatch &batch = iter->Value(); const ColBatch &batch = iter->Value();

View File

@ -355,8 +355,9 @@ class CQHistMaker: public HistMaker<TStats> {
#endif #endif
} }
void ResetPositionAfterSplit(DMatrix *p_fmat, void ResetPositionAfterSplit(DMatrix *p_fmat,
const RegTree &tree) override { const RegTree &tree) override {
this->ResetPositionCol(this->qexpand, p_fmat, tree); this->ResetPositionCol(this->qexpand, p_fmat, tree);
this->GetSplitSet(this->qexpand, tree, &fsplit_set);
} }
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair, void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat, DMatrix *p_fmat,
@ -388,14 +389,10 @@ class CQHistMaker: public HistMaker<TStats> {
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
summary_array[i].Reserve(max_size); summary_array[i].Reserve(max_size);
} }
// if it is C++11, use lazy evaluation for Allreduce {
#if __cplusplus >= 201103L
auto lazy_get_summary = [&]()
#endif
{
// get smmary // get smmary
thread_sketch.resize(this->get_nthread()); thread_sketch.resize(this->get_nthread());
// number of rows in // number of rows in data
const size_t nrows = p_fmat->buffered_rowset().size(); const size_t nrows = p_fmat->buffered_rowset().size();
// start accumulating statistics // start accumulating statistics
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set); dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set);
@ -422,15 +419,10 @@ class CQHistMaker: public HistMaker<TStats> {
summary_array[i].SetPrune(out, max_size); summary_array[i].SetPrune(out, max_size);
} }
CHECK_EQ(summary_array.size(), sketchs.size()); CHECK_EQ(summary_array.size(), sketchs.size());
}; }
if (summary_array.size() != 0) { if (summary_array.size() != 0) {
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
#if __cplusplus >= 201103L
sreducer.Allreduce(dmlc::BeginPtr(summary_array),
nbytes, summary_array.size(), lazy_get_summary);
#else
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
#endif
} }
// now we get the final result of sketch, setup the cut // now we get the final result of sketch, setup the cut
this->wspace.cut.clear(); this->wspace.cut.clear();
@ -617,6 +609,8 @@ class CQHistMaker: public HistMaker<TStats> {
std::vector<int> feat2workindex; std::vector<int> feat2workindex;
// set of index from fset that are real // set of index from fset that are real
std::vector<bst_uint> freal_set; std::vector<bst_uint> freal_set;
// set of index from that are split candidates.
std::vector<bst_uint> fsplit_set;
// thread temp data // thread temp data
std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch; std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch;
// used to hold statistics // used to hold statistics
@ -633,6 +627,7 @@ class CQHistMaker: public HistMaker<TStats> {
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs; std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
}; };
template<typename TStats> template<typename TStats>
class QuantileHistMaker: public HistMaker<TStats> { class QuantileHistMaker: public HistMaker<TStats> {
protected: protected: