[TREE] Refactor histmaker
This commit is contained in:
parent
468bc7725a
commit
52227a8920
@ -206,6 +206,16 @@ class BaseMaker: public TreeUpdater {
|
|||||||
const RegTree &tree) {
|
const RegTree &tree) {
|
||||||
// set the positions in the nondefault
|
// set the positions in the nondefault
|
||||||
this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
|
this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
|
||||||
|
this->SetDefaultPostion(p_fmat, tree);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief helper function to set the non-leaf positions to default direction.
|
||||||
|
* This function can be applied multiple times and will get the same result.
|
||||||
|
* \param p_fmat feature matrix needed for tree construction
|
||||||
|
* \param tree the regression tree structure
|
||||||
|
*/
|
||||||
|
inline void SetDefaultPostion(DMatrix *p_fmat,
|
||||||
|
const RegTree &tree) {
|
||||||
// set rest of instances to default position
|
// set rest of instances to default position
|
||||||
const RowSet &rowset = p_fmat->buffered_rowset();
|
const RowSet &rowset = p_fmat->buffered_rowset();
|
||||||
// set default direct nodes to default
|
// set default direct nodes to default
|
||||||
@ -222,7 +232,7 @@ class BaseMaker: public TreeUpdater {
|
|||||||
if (tree[nid].cright() == -1) {
|
if (tree[nid].cright() == -1) {
|
||||||
position[ridx] = ~nid;
|
position[ridx] = ~nid;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// push to default branch
|
// push to default branch
|
||||||
if (tree[nid].default_left()) {
|
if (tree[nid].default_left()) {
|
||||||
this->SetEncodePosition(ridx, tree[nid].cleft());
|
this->SetEncodePosition(ridx, tree[nid].cleft());
|
||||||
@ -234,16 +244,16 @@ class BaseMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief this is helper function uses column based data structure,
|
* \brief this is helper function uses column based data structure,
|
||||||
* update all positions into nondefault branch, if any, ignore the default branch
|
|
||||||
* \param nodes the set of nodes that contains the split to be used
|
* \param nodes the set of nodes that contains the split to be used
|
||||||
* \param p_fmat feature matrix needed for tree construction
|
|
||||||
* \param tree the regression tree structure
|
* \param tree the regression tree structure
|
||||||
|
* \param out_split_set The split index set
|
||||||
*/
|
*/
|
||||||
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
|
inline void GetSplitSet(const std::vector<int> &nodes,
|
||||||
DMatrix *p_fmat,
|
const RegTree &tree,
|
||||||
const RegTree &tree) {
|
std::vector<unsigned>* out_split_set) {
|
||||||
|
std::vector<unsigned>& fsplits = *out_split_set;
|
||||||
|
fsplits.clear();
|
||||||
// step 1, classify the non-default data into right places
|
// step 1, classify the non-default data into right places
|
||||||
std::vector<unsigned> fsplits;
|
|
||||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||||
const int nid = nodes[i];
|
const int nid = nodes[i];
|
||||||
if (!tree[nid].is_leaf()) {
|
if (!tree[nid].is_leaf()) {
|
||||||
@ -252,7 +262,19 @@ class BaseMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
std::sort(fsplits.begin(), fsplits.end());
|
std::sort(fsplits.begin(), fsplits.end());
|
||||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief this is helper function uses column based data structure,
|
||||||
|
* update all positions into nondefault branch, if any, ignore the default branch
|
||||||
|
* \param nodes the set of nodes that contains the split to be used
|
||||||
|
* \param p_fmat feature matrix needed for tree construction
|
||||||
|
* \param tree the regression tree structure
|
||||||
|
*/
|
||||||
|
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
|
||||||
|
DMatrix *p_fmat,
|
||||||
|
const RegTree &tree) {
|
||||||
|
std::vector<unsigned> fsplits;
|
||||||
|
this->GetSplitSet(nodes, tree, &fsplits);
|
||||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
|
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(fsplits);
|
||||||
while (iter->Next()) {
|
while (iter->Next()) {
|
||||||
const ColBatch &batch = iter->Value();
|
const ColBatch &batch = iter->Value();
|
||||||
|
|||||||
@ -355,8 +355,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
||||||
const RegTree &tree) override {
|
const RegTree &tree) override {
|
||||||
this->ResetPositionCol(this->qexpand, p_fmat, tree);
|
this->ResetPositionCol(this->qexpand, p_fmat, tree);
|
||||||
|
this->GetSplitSet(this->qexpand, tree, &fsplit_set);
|
||||||
}
|
}
|
||||||
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||||
DMatrix *p_fmat,
|
DMatrix *p_fmat,
|
||||||
@ -388,14 +389,10 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
summary_array[i].Reserve(max_size);
|
summary_array[i].Reserve(max_size);
|
||||||
}
|
}
|
||||||
// if it is C++11, use lazy evaluation for Allreduce
|
{
|
||||||
#if __cplusplus >= 201103L
|
|
||||||
auto lazy_get_summary = [&]()
|
|
||||||
#endif
|
|
||||||
{
|
|
||||||
// get smmary
|
// get smmary
|
||||||
thread_sketch.resize(this->get_nthread());
|
thread_sketch.resize(this->get_nthread());
|
||||||
// number of rows in
|
// number of rows in data
|
||||||
const size_t nrows = p_fmat->buffered_rowset().size();
|
const size_t nrows = p_fmat->buffered_rowset().size();
|
||||||
// start accumulating statistics
|
// start accumulating statistics
|
||||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set);
|
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set);
|
||||||
@ -422,15 +419,10 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
summary_array[i].SetPrune(out, max_size);
|
summary_array[i].SetPrune(out, max_size);
|
||||||
}
|
}
|
||||||
CHECK_EQ(summary_array.size(), sketchs.size());
|
CHECK_EQ(summary_array.size(), sketchs.size());
|
||||||
};
|
}
|
||||||
if (summary_array.size() != 0) {
|
if (summary_array.size() != 0) {
|
||||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||||
#if __cplusplus >= 201103L
|
|
||||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array),
|
|
||||||
nbytes, summary_array.size(), lazy_get_summary);
|
|
||||||
#else
|
|
||||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
this->wspace.cut.clear();
|
this->wspace.cut.clear();
|
||||||
@ -617,6 +609,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
std::vector<int> feat2workindex;
|
std::vector<int> feat2workindex;
|
||||||
// set of index from fset that are real
|
// set of index from fset that are real
|
||||||
std::vector<bst_uint> freal_set;
|
std::vector<bst_uint> freal_set;
|
||||||
|
// set of index from that are split candidates.
|
||||||
|
std::vector<bst_uint> fsplit_set;
|
||||||
// thread temp data
|
// thread temp data
|
||||||
std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch;
|
||||||
// used to hold statistics
|
// used to hold statistics
|
||||||
@ -633,6 +627,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
template<typename TStats>
|
template<typename TStats>
|
||||||
class QuantileHistMaker: public HistMaker<TStats> {
|
class QuantileHistMaker: public HistMaker<TStats> {
|
||||||
protected:
|
protected:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user