simplify
This commit is contained in:
parent
5de0a2cdc0
commit
ce7ecadf5e
@ -377,8 +377,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||||
}
|
}
|
||||||
std::vector< std::vector<SketchEntry> > stemp;
|
thread_temp.resize(this->get_nthread());
|
||||||
stemp.resize(this->get_nthread());
|
|
||||||
|
|
||||||
// start accumulating statistics
|
// start accumulating statistics
|
||||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
|
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
|
||||||
@ -390,7 +389,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
#pragma omp parallel for schedule(dynamic, 1)
|
#pragma omp parallel for schedule(dynamic, 1)
|
||||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||||
this->MakeSketch(gpair, batch[i], tree, batch.col_index[i],
|
this->MakeSketch(gpair, batch[i], tree, batch.col_index[i],
|
||||||
&stemp[omp_get_thread_num()]);
|
&thread_temp[omp_get_thread_num()]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// setup maximum size
|
// setup maximum size
|
||||||
@ -460,7 +459,6 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
* \param max_size
|
* \param max_size
|
||||||
*/
|
*/
|
||||||
inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
|
inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
|
||||||
if (w == 0.0f) return;
|
|
||||||
if (wmin == 0.0f) {
|
if (wmin == 0.0f) {
|
||||||
last_fvalue = fvalue;
|
last_fvalue = fvalue;
|
||||||
wmin = w;
|
wmin = w;
|
||||||
@ -520,56 +518,52 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
if (c.length == 0) return;
|
if (c.length == 0) return;
|
||||||
// initialize sbuilder for use
|
// initialize sbuilder for use
|
||||||
std::vector<SketchEntry> &sbuilder = *p_temp;
|
std::vector<SketchEntry> &sbuilder = *p_temp;
|
||||||
sbuilder.resize(this->qexpand.size());
|
sbuilder.resize(tree.param.num_nodes);
|
||||||
for (size_t i = 0; i < sbuilder.size(); ++i) {
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
sbuilder[i].sum_total = 0.0f;
|
const unsigned nid = this->qexpand[i];
|
||||||
sbuilder[i].sketch = &sketchs[i * tree.param.num_feature + fid];
|
const unsigned wid = this->node2workindex[nid];
|
||||||
|
sbuilder[nid].sum_total = 0.0f;
|
||||||
|
sbuilder[nid].sketch = &sketchs[wid * tree.param.num_feature + fid];
|
||||||
}
|
}
|
||||||
// second pass, build the sketch
|
|
||||||
for (bst_uint j = 0; j < c.length; ++j) {
|
|
||||||
const bst_uint ridx = c[j].index;
|
|
||||||
const int nid = this->position[ridx];
|
|
||||||
if (nid >= 0) {
|
|
||||||
const int wid = this->node2workindex[nid];
|
|
||||||
sbuilder[wid].sketch->Push(c[j].fvalue, gpair[ridx].hess);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
// first pass, get sum of weight, TODO, optimization to skip first pass
|
// first pass, get sum of weight, TODO, optimization to skip first pass
|
||||||
for (bst_uint j = 0; j < c.length; ++j) {
|
for (bst_uint j = 0; j < c.length; ++j) {
|
||||||
const bst_uint ridx = c[j].index;
|
const bst_uint ridx = c[j].index;
|
||||||
const int nid = this->position[ridx];
|
const int nid = this->position[ridx];
|
||||||
if (nid >= 0) {
|
if (nid >= 0) {
|
||||||
const int wid = this->node2workindex[nid];
|
sbuilder[nid].sum_total += gpair[ridx].hess;
|
||||||
sbuilder[wid].sum_total += gpair[ridx].hess;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if only one value, no need to do second pass
|
// if only one value, no need to do second pass
|
||||||
if (c[0].fvalue == c[c.length-1].fvalue) {
|
if (c[0].fvalue == c[c.length-1].fvalue) {
|
||||||
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
sbuilder[wid].sketch->Push(c[0].fvalue, sbuilder[wid].sum_total);
|
const int nid = this->qexpand[i];
|
||||||
|
sbuilder[nid].sketch->Push(c[0].fvalue, sbuilder[nid].sum_total);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// two pass scan
|
// two pass scan
|
||||||
unsigned max_size = static_cast<unsigned>(this->param.sketch_ratio / this->param.sketch_eps);
|
unsigned max_size = static_cast<unsigned>(this->param.sketch_ratio / this->param.sketch_eps);
|
||||||
for (size_t wid = 0; wid < sbuilder.size(); ++wid) {
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
sbuilder[wid].Init(max_size);
|
const int nid = this->qexpand[i];
|
||||||
|
sbuilder[nid].Init(max_size);
|
||||||
}
|
}
|
||||||
// second pass, build the sketch
|
// second pass, build the sketch
|
||||||
for (bst_uint j = 0; j < c.length; ++j) {
|
for (bst_uint j = 0; j < c.length; ++j) {
|
||||||
const bst_uint ridx = c[j].index;
|
const bst_uint ridx = c[j].index;
|
||||||
const int nid = this->position[ridx];
|
const int nid = this->position[ridx];
|
||||||
if (nid >= 0) {
|
if (nid >= 0) {
|
||||||
const int wid = this->node2workindex[nid];
|
sbuilder[nid].Push(c[j].fvalue, gpair[ridx].hess, max_size);
|
||||||
sbuilder[wid].Push(c[j].fvalue, gpair[ridx].hess, max_size);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (size_t wid = 0; wid < sbuilder.size(); ++wid) {
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
sbuilder[wid].Finalize(max_size);
|
const int nid = this->qexpand[i];
|
||||||
|
sbuilder[nid].Finalize(max_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||||
|
// thread temp data
|
||||||
|
std::vector< std::vector<SketchEntry> > thread_temp;
|
||||||
// summary array
|
// summary array
|
||||||
WXQSketch::SummaryArray summary_array;
|
WXQSketch::SummaryArray summary_array;
|
||||||
// reducer for summary
|
// reducer for summary
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user