[TREE] Refactor histmaker
This commit is contained in:
parent
468bc7725a
commit
52227a8920
@ -206,6 +206,16 @@ class BaseMaker: public TreeUpdater {
|
||||
const RegTree &tree) {
|
||||
// set the positions in the nondefault
|
||||
this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
|
||||
this->SetDefaultPostion(p_fmat, tree);
|
||||
}
|
||||
/*!
|
||||
* \brief helper function to set the non-leaf positions to default direction.
|
||||
* This function can be applied multiple times and will get the same result.
|
||||
* \param p_fmat feature matrix needed for tree construction
|
||||
* \param tree the regression tree structure
|
||||
*/
|
||||
inline void SetDefaultPostion(DMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
// set rest of instances to default position
|
||||
const RowSet &rowset = p_fmat->buffered_rowset();
|
||||
// set default direct nodes to default
|
||||
@ -222,7 +232,7 @@ class BaseMaker: public TreeUpdater {
|
||||
if (tree[nid].cright() == -1) {
|
||||
position[ridx] = ~nid;
|
||||
}
|
||||
} else {
|
||||
} else {
|
||||
// push to default branch
|
||||
if (tree[nid].default_left()) {
|
||||
this->SetEncodePosition(ridx, tree[nid].cleft());
|
||||
@ -234,16 +244,16 @@ class BaseMaker: public TreeUpdater {
|
||||
}
|
||||
/*!
|
||||
* \brief this is helper function uses column based data structure,
|
||||
* update all positions into nondefault branch, if any, ignore the default branch
|
||||
* \param nodes the set of nodes that contains the split to be used
|
||||
* \param p_fmat feature matrix needed for tree construction
|
||||
* \param tree the regression tree structure
|
||||
* \param out_split_set The split index set
|
||||
*/
|
||||
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
|
||||
DMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
inline void GetSplitSet(const std::vector<int> &nodes,
|
||||
const RegTree &tree,
|
||||
std::vector<unsigned>* out_split_set) {
|
||||
std::vector<unsigned>& fsplits = *out_split_set;
|
||||
fsplits.clear();
|
||||
// step 1, classify the non-default data into right places
|
||||
std::vector<unsigned> fsplits;
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
const int nid = nodes[i];
|
||||
if (!tree[nid].is_leaf()) {
|
||||
@ -252,7 +262,19 @@ class BaseMaker: public TreeUpdater {
|
||||
}
|
||||
std::sort(fsplits.begin(), fsplits.end());
|
||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||
|
||||
}
|
||||
/*!
|
||||
* \brief this is helper function uses column based data structure,
|
||||
* update all positions into nondefault branch, if any, ignore the default branch
|
||||
* \param nodes the set of nodes that contains the split to be used
|
||||
* \param p_fmat feature matrix needed for tree construction
|
||||
* \param tree the regression tree structure
|
||||
*/
|
||||
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
|
||||
DMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
std::vector<unsigned> fsplits;
|
||||
this->GetSplitSet(nodes, tree, &fsplits);
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
|
||||
@ -355,8 +355,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
#endif
|
||||
}
|
||||
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
||||
const RegTree &tree) override {
|
||||
const RegTree &tree) override {
|
||||
this->ResetPositionCol(this->qexpand, p_fmat, tree);
|
||||
this->GetSplitSet(this->qexpand, tree, &fsplit_set);
|
||||
}
|
||||
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
@ -388,14 +389,10 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
summary_array[i].Reserve(max_size);
|
||||
}
|
||||
// if it is C++11, use lazy evaluation for Allreduce
|
||||
#if __cplusplus >= 201103L
|
||||
auto lazy_get_summary = [&]()
|
||||
#endif
|
||||
{
|
||||
{
|
||||
// get smmary
|
||||
thread_sketch.resize(this->get_nthread());
|
||||
// number of rows in
|
||||
// number of rows in data
|
||||
const size_t nrows = p_fmat->buffered_rowset().size();
|
||||
// start accumulating statistics
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set);
|
||||
@ -422,15 +419,10 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
CHECK_EQ(summary_array.size(), sketchs.size());
|
||||
};
|
||||
}
|
||||
if (summary_array.size() != 0) {
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
#if __cplusplus >= 201103L
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array),
|
||||
nbytes, summary_array.size(), lazy_get_summary);
|
||||
#else
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
#endif
|
||||
}
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace.cut.clear();
|
||||
@ -617,6 +609,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
std::vector<int> feat2workindex;
|
||||
// set of index from fset that are real
|
||||
std::vector<bst_uint> freal_set;
|
||||
// set of index from that are split candidates.
|
||||
std::vector<bst_uint> fsplit_set;
|
||||
// thread temp data
|
||||
std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
||||
// used to hold statistics
|
||||
@ -633,6 +627,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
|
||||
|
||||
template<typename TStats>
|
||||
class QuantileHistMaker: public HistMaker<TStats> {
|
||||
protected:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user