[TREE] Cleanup some functions, add utility function for two pass
This commit is contained in:
parent
52227a8920
commit
523afcbcd2
@ -242,6 +242,45 @@ class BaseMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief this is helper function uses column based data structure,
|
||||||
|
* to CORRECT the positions of non-default directions that WAS set to default
|
||||||
|
* before calling this function.
|
||||||
|
* \param batch The column batch
|
||||||
|
* \param sorted_split_set The set of index that contains split solutions.
|
||||||
|
* \param tree the regression tree structure
|
||||||
|
*/
|
||||||
|
inline void CorrectNonDefaultPositionByBatch(
|
||||||
|
const ColBatch& batch,
|
||||||
|
const std::vector<bst_uint> &sorted_split_set,
|
||||||
|
const RegTree &tree) {
|
||||||
|
for (size_t i = 0; i < batch.size; ++i) {
|
||||||
|
ColBatch::Inst col = batch[i];
|
||||||
|
const bst_uint fid = batch.col_index[i];
|
||||||
|
auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid);
|
||||||
|
|
||||||
|
if (it != sorted_split_set.end() && *it == fid) {
|
||||||
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||||
|
const bst_uint ridx = col[j].index;
|
||||||
|
const float fvalue = col[j].fvalue;
|
||||||
|
const int nid = this->DecodePosition(ridx);
|
||||||
|
CHECK(tree[nid].is_leaf());
|
||||||
|
int pid = tree[nid].parent();
|
||||||
|
|
||||||
|
// go back to parent, correct those who are not default
|
||||||
|
if (!tree[nid].is_root() && tree[pid].split_index() == fid) {
|
||||||
|
if (fvalue < tree[pid].split_cond()) {
|
||||||
|
this->SetEncodePosition(ridx, tree[pid].cleft());
|
||||||
|
} else {
|
||||||
|
this->SetEncodePosition(ridx, tree[pid].cright());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief this is helper function uses column based data structure,
|
* \brief this is helper function uses column based data structure,
|
||||||
* \param nodes the set of nodes that contains the split to be used
|
* \param nodes the set of nodes that contains the split to be used
|
||||||
|
|||||||
@ -127,6 +127,11 @@ class HistMaker: public BaseMaker {
|
|||||||
RegTree *p_tree) {
|
RegTree *p_tree) {
|
||||||
this->InitData(gpair, *p_fmat, *p_tree);
|
this->InitData(gpair, *p_fmat, *p_tree);
|
||||||
this->InitWorkSet(p_fmat, *p_tree, &fwork_set);
|
this->InitWorkSet(p_fmat, *p_tree, &fwork_set);
|
||||||
|
// mark root node as fresh.
|
||||||
|
for (int i = 0; i < p_tree->param.num_roots; ++i) {
|
||||||
|
(*p_tree)[i].set_leaf(0.0f, 0);
|
||||||
|
}
|
||||||
|
|
||||||
for (int depth = 0; depth < param.max_depth; ++depth) {
|
for (int depth = 0; depth < param.max_depth; ++depth) {
|
||||||
// reset and propose candidate split
|
// reset and propose candidate split
|
||||||
this->ResetPosAndPropose(gpair, p_fmat, fwork_set, *p_tree);
|
this->ResetPosAndPropose(gpair, p_fmat, fwork_set, *p_tree);
|
||||||
@ -356,8 +361,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
}
|
}
|
||||||
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
void ResetPositionAfterSplit(DMatrix *p_fmat,
|
||||||
const RegTree &tree) override {
|
const RegTree &tree) override {
|
||||||
|
// remove this reset and do two pass reset on ResetPosAndPropose
|
||||||
this->ResetPositionCol(this->qexpand, p_fmat, tree);
|
this->ResetPositionCol(this->qexpand, p_fmat, tree);
|
||||||
this->GetSplitSet(this->qexpand, tree, &fsplit_set);
|
|
||||||
}
|
}
|
||||||
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||||
DMatrix *p_fmat,
|
DMatrix *p_fmat,
|
||||||
@ -367,18 +372,18 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
// fill in reverse map
|
// fill in reverse map
|
||||||
feat2workindex.resize(tree.param.num_feature);
|
feat2workindex.resize(tree.param.num_feature);
|
||||||
std::fill(feat2workindex.begin(), feat2workindex.end(), -1);
|
std::fill(feat2workindex.begin(), feat2workindex.end(), -1);
|
||||||
freal_set.clear();
|
work_set.clear();
|
||||||
for (size_t i = 0; i < fset.size(); ++i) {
|
for (size_t i = 0; i < fset.size(); ++i) {
|
||||||
if (feat_helper.Type(fset[i]) == 2) {
|
if (feat_helper.Type(fset[i]) == 2) {
|
||||||
feat2workindex[fset[i]] = static_cast<int>(freal_set.size());
|
feat2workindex[fset[i]] = static_cast<int>(work_set.size());
|
||||||
freal_set.push_back(fset[i]);
|
work_set.push_back(fset[i]);
|
||||||
} else {
|
} else {
|
||||||
feat2workindex[fset[i]] = -2;
|
feat2workindex[fset[i]] = -2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this->GetNodeStats(gpair, *p_fmat, tree,
|
const size_t work_set_size = work_set.size();
|
||||||
&thread_stats, &node_stats);
|
|
||||||
sketchs.resize(this->qexpand.size() * freal_set.size());
|
sketchs.resize(this->qexpand.size() * work_set_size);
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||||
}
|
}
|
||||||
@ -392,10 +397,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
{
|
{
|
||||||
// get summary
|
// get summary
|
||||||
thread_sketch.resize(this->get_nthread());
|
thread_sketch.resize(this->get_nthread());
|
||||||
// number of rows in data
|
|
||||||
const size_t nrows = p_fmat->buffered_rowset().size();
|
|
||||||
// start accumulating statistics
|
// start accumulating statistics
|
||||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(freal_set);
|
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(work_set);
|
||||||
iter->BeforeFirst();
|
iter->BeforeFirst();
|
||||||
while (iter->Next()) {
|
while (iter->Next()) {
|
||||||
const ColBatch &batch = iter->Value();
|
const ColBatch &batch = iter->Value();
|
||||||
@ -406,9 +410,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
int offset = feat2workindex[batch.col_index[i]];
|
int offset = feat2workindex[batch.col_index[i]];
|
||||||
if (offset >= 0) {
|
if (offset >= 0) {
|
||||||
this->UpdateSketchCol(gpair, batch[i], tree,
|
this->UpdateSketchCol(gpair, batch[i], tree,
|
||||||
node_stats,
|
work_set_size, offset,
|
||||||
freal_set, offset,
|
|
||||||
batch[i].length == nrows,
|
|
||||||
&thread_sketch[omp_get_thread_num()]);
|
&thread_sketch[omp_get_thread_num()]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -424,6 +426,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||||
}
|
}
|
||||||
|
// update node statistics.
|
||||||
|
this->GetNodeStats(gpair, *p_fmat, tree,
|
||||||
|
&thread_stats, &node_stats);
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
this->wspace.cut.clear();
|
this->wspace.cut.clear();
|
||||||
this->wspace.rptr.clear();
|
this->wspace.rptr.clear();
|
||||||
@ -432,7 +437,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
for (size_t i = 0; i < fset.size(); ++i) {
|
for (size_t i = 0; i < fset.size(); ++i) {
|
||||||
int offset = feat2workindex[fset[i]];
|
int offset = feat2workindex[fset[i]];
|
||||||
if (offset >= 0) {
|
if (offset >= 0) {
|
||||||
const WXQSketch::Summary &a = summary_array[wid * freal_set.size() + offset];
|
const WXQSketch::Summary &a = summary_array[wid * work_set_size + offset];
|
||||||
for (size_t i = 1; i < a.size; ++i) {
|
for (size_t i = 1; i < a.size; ++i) {
|
||||||
bst_float cpt = a.data[i].value - rt_eps;
|
bst_float cpt = a.data[i].value - rt_eps;
|
||||||
if (i == 1 || cpt > this->wspace.cut.back()) {
|
if (i == 1 || cpt > this->wspace.cut.back()) {
|
||||||
@ -518,10 +523,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
||||||
const ColBatch::Inst &c,
|
const ColBatch::Inst &c,
|
||||||
const RegTree &tree,
|
const RegTree &tree,
|
||||||
const std::vector<TStats> &nstats,
|
size_t work_set_size,
|
||||||
const std::vector<bst_uint> &frealset,
|
|
||||||
bst_uint offset,
|
bst_uint offset,
|
||||||
bool col_full,
|
|
||||||
std::vector<BaseMaker::SketchEntry> *p_temp) {
|
std::vector<BaseMaker::SketchEntry> *p_temp) {
|
||||||
if (c.length == 0) return;
|
if (c.length == 0) return;
|
||||||
// initialize sbuilder for use
|
// initialize sbuilder for use
|
||||||
@ -531,10 +534,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
const unsigned nid = this->qexpand[i];
|
const unsigned nid = this->qexpand[i];
|
||||||
const unsigned wid = this->node2workindex[nid];
|
const unsigned wid = this->node2workindex[nid];
|
||||||
sbuilder[nid].sum_total = 0.0f;
|
sbuilder[nid].sum_total = 0.0f;
|
||||||
sbuilder[nid].sketch = &sketchs[wid * frealset.size() + offset];
|
sbuilder[nid].sketch = &sketchs[wid * work_set_size + offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!col_full) {
|
|
||||||
// first pass, get sum of weight, TODO, optimization to skip first pass
|
// first pass, get sum of weight, TODO, optimization to skip first pass
|
||||||
for (bst_uint j = 0; j < c.length; ++j) {
|
for (bst_uint j = 0; j < c.length; ++j) {
|
||||||
const bst_uint ridx = c[j].index;
|
const bst_uint ridx = c[j].index;
|
||||||
@ -543,12 +545,6 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
sbuilder[nid].sum_total += gpair[ridx].hess;
|
sbuilder[nid].sum_total += gpair[ridx].hess;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
|
||||||
const unsigned nid = this->qexpand[i];
|
|
||||||
sbuilder[nid].sum_total = static_cast<bst_float>(nstats[nid].sum_hess);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// if only one value, no need to do second pass
|
// if only one value, no need to do second pass
|
||||||
if (c[0].fvalue == c[c.length-1].fvalue) {
|
if (c[0].fvalue == c[c.length-1].fvalue) {
|
||||||
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
@ -607,8 +603,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
BaseMaker::FMetaHelper feat_helper;
|
BaseMaker::FMetaHelper feat_helper;
|
||||||
// temp space to map feature id to working index
|
// temp space to map feature id to working index
|
||||||
std::vector<int> feat2workindex;
|
std::vector<int> feat2workindex;
|
||||||
// set of index from fset that are real
|
// set of index from fset that are current work set
|
||||||
std::vector<bst_uint> freal_set;
|
std::vector<bst_uint> work_set;
|
||||||
// set of indices that are split candidates.
|
// set of indices that are split candidates.
|
||||||
std::vector<bst_uint> fsplit_set;
|
std::vector<bst_uint> fsplit_set;
|
||||||
// thread temp data
|
// thread temp data
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user