Hack to make split-candidate proposal run fast in a single pass; start the sketch maker
This commit is contained in:
parent
ce7ecadf5e
commit
303f8b9bc5
@ -19,10 +19,8 @@ IUpdater* CreateUpdater(const char *name) {
|
|||||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
||||||
if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker<GradStats>();
|
if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker<GradStats>();
|
||||||
if (!strcmp(name, "grow_cqmaker")) return new CQHistMaker<GradStats>();
|
if (!strcmp(name, "grow_cqmaker")) return new CQHistMaker<GradStats>();
|
||||||
if (!strcmp(name, "grow_chistmaker")) return new ColumnHistMaker<GradStats>();
|
|
||||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||||
if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >();
|
|
||||||
if (!strcmp(name, "grow_colmaker3")) return new ColMaker< CVGradStats<3> >();
|
|
||||||
utils::Error("unknown updater:%s", name);
|
utils::Error("unknown updater:%s", name);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -285,87 +285,6 @@ class HistMaker: public BaseMaker {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename TStats>
|
|
||||||
class ColumnHistMaker: public HistMaker<TStats> {
|
|
||||||
public:
|
|
||||||
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
|
||||||
IFMatrix *p_fmat,
|
|
||||||
const BoosterInfo &info,
|
|
||||||
const RegTree &tree) {
|
|
||||||
sketchs.resize(tree.param.num_feature);
|
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
|
||||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
|
||||||
}
|
|
||||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
|
|
||||||
while (iter->Next()) {
|
|
||||||
const ColBatch &batch = iter->Value();
|
|
||||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
|
||||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
|
||||||
const bst_uint fid = batch.col_index[i];
|
|
||||||
const ColBatch::Inst &col = batch[i];
|
|
||||||
unsigned nstep = col.length * (this->param.sketch_eps / this->param.sketch_ratio);
|
|
||||||
if (nstep == 0) nstep = 1;
|
|
||||||
for (unsigned i = 0; i < col.length; i += nstep) {
|
|
||||||
sketchs[fid].Push(col[i].fvalue);
|
|
||||||
}
|
|
||||||
if (col.length != 0 && col.length - 1 % nstep != 0) {
|
|
||||||
sketchs[fid].Push(col[col.length-1].fvalue);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
|
|
||||||
// synchronize sketch
|
|
||||||
summary_array.Init(sketchs.size(), max_size);
|
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
|
||||||
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
|
||||||
sketchs[i].GetSummary(&out);
|
|
||||||
summary_array.Set(i, out);
|
|
||||||
}
|
|
||||||
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
|
||||||
sreducer.AllReduce(&summary_array, n4bytes);
|
|
||||||
// now we get the final result of sketch, setup the cut
|
|
||||||
this->wspace.cut.clear();
|
|
||||||
this->wspace.rptr.clear();
|
|
||||||
this->wspace.rptr.push_back(0);
|
|
||||||
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
|
||||||
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
|
||||||
const WXQSketch::Summary a = summary_array[fid];
|
|
||||||
for (size_t i = 1; i < a.size; ++i) {
|
|
||||||
bst_float cpt = a.data[i].value - rt_eps;
|
|
||||||
if (i == 1 || cpt > this->wspace.cut.back()) {
|
|
||||||
this->wspace.cut.push_back(cpt);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// push a value that is greater than anything
|
|
||||||
if (a.size != 0) {
|
|
||||||
bst_float cpt = a.data[a.size - 1].value;
|
|
||||||
// this must be bigger than last value in a scale
|
|
||||||
bst_float last = cpt + fabs(cpt) + rt_eps;
|
|
||||||
this->wspace.cut.push_back(last);
|
|
||||||
}
|
|
||||||
this->wspace.rptr.push_back(this->wspace.cut.size());
|
|
||||||
}
|
|
||||||
// reserve last value for global statistics
|
|
||||||
this->wspace.cut.push_back(0.0f);
|
|
||||||
this->wspace.rptr.push_back(this->wspace.cut.size());
|
|
||||||
}
|
|
||||||
utils::Assert(this->wspace.rptr.size() ==
|
|
||||||
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
|
|
||||||
"cut space inconsistent");
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
|
||||||
// summary array
|
|
||||||
WXQSketch::SummaryArray summary_array;
|
|
||||||
// reducer for summary
|
|
||||||
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
|
|
||||||
// per feature sketch
|
|
||||||
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename TStats>
|
template<typename TStats>
|
||||||
class CQHistMaker: public HistMaker<TStats> {
|
class CQHistMaker: public HistMaker<TStats> {
|
||||||
protected:
|
protected:
|
||||||
@ -378,7 +297,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||||
}
|
}
|
||||||
thread_temp.resize(this->get_nthread());
|
thread_temp.resize(this->get_nthread());
|
||||||
|
std::vector<bst_float> root_stats;
|
||||||
|
this->GetRootStats(gpair, *p_fmat, tree, &root_stats);
|
||||||
// start accumulating statistics
|
// start accumulating statistics
|
||||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
|
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
|
||||||
iter->BeforeFirst();
|
iter->BeforeFirst();
|
||||||
@ -388,7 +308,10 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
#pragma omp parallel for schedule(dynamic, 1)
|
||||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||||
this->MakeSketch(gpair, batch[i], tree, batch.col_index[i],
|
this->MakeSketch(gpair, batch[i], tree,
|
||||||
|
root_stats,
|
||||||
|
batch.col_index[i],
|
||||||
|
p_fmat->GetColDensity(batch.col_index[i]),
|
||||||
&thread_temp[omp_get_thread_num()]);
|
&thread_temp[omp_get_thread_num()]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -513,7 +436,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
inline void MakeSketch(const std::vector<bst_gpair> &gpair,
|
inline void MakeSketch(const std::vector<bst_gpair> &gpair,
|
||||||
const ColBatch::Inst &c,
|
const ColBatch::Inst &c,
|
||||||
const RegTree &tree,
|
const RegTree &tree,
|
||||||
|
const std::vector<bst_float> &root_stats,
|
||||||
bst_uint fid,
|
bst_uint fid,
|
||||||
|
float col_density,
|
||||||
std::vector<SketchEntry> *p_temp) {
|
std::vector<SketchEntry> *p_temp) {
|
||||||
if (c.length == 0) return;
|
if (c.length == 0) return;
|
||||||
// initialize sbuilder for use
|
// initialize sbuilder for use
|
||||||
@ -526,6 +451,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
sbuilder[nid].sketch = &sketchs[wid * tree.param.num_feature + fid];
|
sbuilder[nid].sketch = &sketchs[wid * tree.param.num_feature + fid];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (col_density != 1.0f) {
|
||||||
// first pass, get sum of weight, TODO, optimization to skip first pass
|
// first pass, get sum of weight, TODO, optimization to skip first pass
|
||||||
for (bst_uint j = 0; j < c.length; ++j) {
|
for (bst_uint j = 0; j < c.length; ++j) {
|
||||||
const bst_uint ridx = c[j].index;
|
const bst_uint ridx = c[j].index;
|
||||||
@ -534,6 +460,12 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
sbuilder[nid].sum_total += gpair[ridx].hess;
|
sbuilder[nid].sum_total += gpair[ridx].hess;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
|
const unsigned nid = this->qexpand[i];
|
||||||
|
sbuilder[nid].sum_total = root_stats[nid];
|
||||||
|
}
|
||||||
|
}
|
||||||
// if only one value, no need to do second pass
|
// if only one value, no need to do second pass
|
||||||
if (c[0].fvalue == c[c.length-1].fvalue) {
|
if (c[0].fvalue == c[c.length-1].fvalue) {
|
||||||
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
@ -561,6 +493,44 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
sbuilder[nid].Finalize(max_size);
|
sbuilder[nid].Finalize(max_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
inline void GetRootStats(const std::vector<bst_gpair> &gpair,
|
||||||
|
const IFMatrix &fmat,
|
||||||
|
const RegTree &tree,
|
||||||
|
std::vector<float> *p_snode) {
|
||||||
|
std::vector<float> &snode = *p_snode;
|
||||||
|
thread_temp.resize(this->get_nthread());
|
||||||
|
snode.resize(tree.param.num_nodes);
|
||||||
|
#pragma omp parallel
|
||||||
|
{
|
||||||
|
const int tid = omp_get_thread_num();
|
||||||
|
thread_temp[tid].resize(tree.param.num_nodes);
|
||||||
|
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||||
|
const unsigned nid = this->qexpand[i];
|
||||||
|
thread_temp[tid][nid].sum_total = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
|
||||||
|
// setup position
|
||||||
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
|
const bst_uint ridx = rowset[i];
|
||||||
|
const int tid = omp_get_thread_num();
|
||||||
|
if (this->position[ridx] < 0) continue;
|
||||||
|
thread_temp[tid][this->position[ridx]].sum_total += gpair[ridx].hess;
|
||||||
|
}
|
||||||
|
// sum the per thread statistics together
|
||||||
|
for (size_t j = 0; j < this->qexpand.size(); ++j) {
|
||||||
|
const int nid = this->qexpand[j];
|
||||||
|
double wsum = 0.0f;
|
||||||
|
for (size_t tid = 0; tid < thread_temp.size(); ++tid) {
|
||||||
|
wsum += thread_temp[tid][nid].sum_total;
|
||||||
|
}
|
||||||
|
// update node statistics
|
||||||
|
snode[nid] = static_cast<bst_float>(wsum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||||
// thread temp data
|
// thread temp data
|
||||||
std::vector< std::vector<SketchEntry> > thread_temp;
|
std::vector< std::vector<SketchEntry> > thread_temp;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user