sorted base sketch maker

This commit is contained in:
tqchen 2014-11-18 10:19:18 -08:00
parent 5e8e9a9b74
commit 5de0a2cdc0
4 changed files with 248 additions and 34 deletions

View File

@ -18,6 +18,7 @@ IUpdater* CreateUpdater(const char *name) {
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>(); if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>(); if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker<GradStats>(); if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker<GradStats>();
if (!strcmp(name, "grow_cqmaker")) return new CQHistMaker<GradStats>();
if (!strcmp(name, "grow_chistmaker")) return new ColumnHistMaker<GradStats>(); if (!strcmp(name, "grow_chistmaker")) return new ColumnHistMaker<GradStats>();
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>(); if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >(); if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >();

View File

@ -51,7 +51,9 @@ class HistMaker: public BaseMaker {
const BoosterInfo &info, const BoosterInfo &info,
const bst_uint ridx) { const bst_uint ridx) {
unsigned i = std::upper_bound(cut, cut + size, fv) - cut; unsigned i = std::upper_bound(cut, cut + size, fv) - cut;
utils::Assert(i < size, "maximum value must be in cut"); utils::Assert(size != 0, "try insert into size=0");
utils::Assert(i < size,
"maximum value must be in cut, fv = %g, cutmax=%g", fv, cut[size-1]);
data[i].Add(gpair, info, ridx); data[i].Add(gpair, info, ridx);
} }
}; };
@ -122,7 +124,7 @@ class HistMaker: public BaseMaker {
IFMatrix *p_fmat, IFMatrix *p_fmat,
const BoosterInfo &info, const BoosterInfo &info,
const RegTree &tree) = 0; const RegTree &tree) = 0;
private:
virtual void Update(const std::vector<bst_gpair> &gpair, virtual void Update(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, IFMatrix *p_fmat,
const BoosterInfo &info, const BoosterInfo &info,
@ -130,6 +132,7 @@ class HistMaker: public BaseMaker {
this->InitData(gpair, *p_fmat, info.root_index, *p_tree); this->InitData(gpair, *p_fmat, info.root_index, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) { for (int depth = 0; depth < param.max_depth; ++depth) {
this->FindSplit(depth, gpair, p_fmat, info, p_tree); this->FindSplit(depth, gpair, p_fmat, info, p_tree);
this->ResetPositionCol(this->qexpand, p_fmat, *p_tree);
this->UpdateQueueExpand(*p_tree); this->UpdateQueueExpand(*p_tree);
// if nothing left to be expand, break // if nothing left to be expand, break
if (qexpand.size() == 0) break; if (qexpand.size() == 0) break;
@ -139,6 +142,8 @@ class HistMaker: public BaseMaker {
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate); (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
} }
} }
private:
inline void CreateHist(const std::vector<bst_gpair> &gpair, inline void CreateHist(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, IFMatrix *p_fmat,
const BoosterInfo &info, const BoosterInfo &info,
@ -166,11 +171,7 @@ class HistMaker: public BaseMaker {
HistSet &hset = wspace.hset[tid]; HistSet &hset = wspace.hset[tid];
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i); const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
int nid = position[ridx]; int nid = position[ridx];
if (!tree[nid].is_leaf()) {
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
}
if (nid >= 0) { if (nid >= 0) {
utils::Assert(tree[nid].is_leaf(), "CreateHist happens in leaf");
const int wid = this->node2workindex[nid]; const int wid = this->node2workindex[nid];
for (bst_uint i = 0; i < inst.length; ++i) { for (bst_uint i = 0; i < inst.length; ++i) {
utils::Assert(inst[i].index < num_feature, "feature index exceed bound"); utils::Assert(inst[i].index < num_feature, "feature index exceed bound");
@ -365,6 +366,217 @@ class ColumnHistMaker: public HistMaker<TStats> {
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs; std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs;
}; };
template<typename TStats>
class CQHistMaker: public HistMaker<TStats> {
protected:
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat,
const BoosterInfo &info,
const RegTree &tree) {
sketchs.resize(this->qexpand.size() * tree.param.num_feature);
for (size_t i = 0; i < sketchs.size(); ++i) {
sketchs[i].Init(info.num_row, this->param.sketch_eps);
}
std::vector< std::vector<SketchEntry> > stemp;
stemp.resize(this->get_nthread());
// start accumulating statistics
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
// start enumeration
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
this->MakeSketch(gpair, batch[i], tree, batch.col_index[i],
&stemp[omp_get_thread_num()]);
}
}
// setup maximum size
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
// synchronize sketch
summary_array.Init(sketchs.size(), max_size);
for (size_t i = 0; i < sketchs.size(); ++i) {
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs[i].GetSummary(&out);
summary_array.Set(i, out);
}
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
sreducer.AllReduce(&summary_array, n4bytes);
// now we get the final result of sketch, setup the cut
this->wspace.cut.clear();
this->wspace.rptr.clear();
this->wspace.rptr.push_back(0);
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
for (size_t i = 1; i < a.size; ++i) {
bst_float cpt = a.data[i].value - rt_eps;
if (i == 1 || cpt > this->wspace.cut.back()) {
this->wspace.cut.push_back(cpt);
}
}
// push a value that is greater than anything
if (a.size != 0) {
bst_float cpt = a.data[a.size - 1].value;
// this must be bigger than last value in a scale
bst_float last = cpt + fabs(cpt) + rt_eps;
this->wspace.cut.push_back(last);
}
this->wspace.rptr.push_back(this->wspace.cut.size());
}
// reserve last value for global statistics
this->wspace.cut.push_back(0.0f);
this->wspace.rptr.push_back(this->wspace.cut.size());
}
utils::Assert(this->wspace.rptr.size() ==
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
"cut space inconsistent");
}
// temporal space to build a sketch
struct SketchEntry {
/*! \brief total sum of */
bst_float sum_total;
/*! \brief statistics used in the sketch */
bst_float rmin, wmin;
/*! \brief last seen feature value */
bst_float last_fvalue;
/*! \brief current size of sketch */
bst_float next_goal;
// pointer to the sketch to put things in
utils::WXQuantileSketch<bst_float, bst_float> *sketch;
// initialize the space
inline void Init(unsigned max_size) {
next_goal = 0.0f;
rmin = wmin = 0.0f;
sketch->temp.Reserve(max_size + 1);
sketch->temp.size = 0;
}
/*!
* \brief push a new element to sketch
* \param fvalue feature value, comes in sorted ascending order
* \param w weight
* \param max_size
*/
inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
if (w == 0.0f) return;
if (wmin == 0.0f) {
last_fvalue = fvalue;
wmin = w;
return;
}
if (last_fvalue != fvalue) {
bst_float rmax = rmin + wmin;
if (rmax >= next_goal) {
if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
// push to sketch
sketch->temp.data[sketch->temp.size] =
utils::WXQuantileSketch<bst_float, bst_float>::
Entry(rmin, rmax, wmin, last_fvalue);
utils::Assert(sketch->temp.size < max_size,
"invalid maximum size max_size=%u, stemp.size=%lu\n",
max_size, sketch->temp.size);
++sketch->temp.size;
}
if (sketch->temp.size == max_size) {
next_goal = sum_total * 2.0f + 1e-5f;
} else{
next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
}
}
rmin = rmax;
wmin = w;
last_fvalue = fvalue;
} else {
wmin += w;
}
}
/*! \brief push final unfinished value to the sketch */
inline void Finalize(unsigned max_size) {
bst_float rmax = rmin + wmin;
//utils::Assert(fabs(rmax - sum_total) < 1e-4 + sum_total * 1e-5,
//"invalid sum value, rmax=%f, sum_total=%lf", rmax, sum_total);
if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
utils::Assert(sketch->temp.size <= max_size,
"Finalize: invalid maximum size, max_size=%u, stemp.size=%lu",
sketch->temp.size, max_size );
// push to sketch
sketch->temp.data[sketch->temp.size] =
utils::WXQuantileSketch<bst_float, bst_float>::
Entry(rmin, rmax, wmin, last_fvalue);
++sketch->temp.size;
}
sketch->PushTemp();
}
};
private:
inline void MakeSketch(const std::vector<bst_gpair> &gpair,
const ColBatch::Inst &c,
const RegTree &tree,
bst_uint fid,
std::vector<SketchEntry> *p_temp) {
if (c.length == 0) return;
// initialize sbuilder for use
std::vector<SketchEntry> &sbuilder = *p_temp;
sbuilder.resize(this->qexpand.size());
for (size_t i = 0; i < sbuilder.size(); ++i) {
sbuilder[i].sum_total = 0.0f;
sbuilder[i].sketch = &sketchs[i * tree.param.num_feature + fid];
}
// second pass, build the sketch
for (bst_uint j = 0; j < c.length; ++j) {
const bst_uint ridx = c[j].index;
const int nid = this->position[ridx];
if (nid >= 0) {
const int wid = this->node2workindex[nid];
sbuilder[wid].sketch->Push(c[j].fvalue, gpair[ridx].hess);
}
}
return;
// first pass, get sum of weight, TODO, optimization to skip first pass
for (bst_uint j = 0; j < c.length; ++j) {
const bst_uint ridx = c[j].index;
const int nid = this->position[ridx];
if (nid >= 0) {
const int wid = this->node2workindex[nid];
sbuilder[wid].sum_total += gpair[ridx].hess;
}
}
// if only one value, no need to do second pass
if (c[0].fvalue == c[c.length-1].fvalue) {
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
sbuilder[wid].sketch->Push(c[0].fvalue, sbuilder[wid].sum_total);
}
return;
}
// two pass scan
unsigned max_size = static_cast<unsigned>(this->param.sketch_ratio / this->param.sketch_eps);
for (size_t wid = 0; wid < sbuilder.size(); ++wid) {
sbuilder[wid].Init(max_size);
}
// second pass, build the sketch
for (bst_uint j = 0; j < c.length; ++j) {
const bst_uint ridx = c[j].index;
const int nid = this->position[ridx];
if (nid >= 0) {
const int wid = this->node2workindex[nid];
sbuilder[wid].Push(c[j].fvalue, gpair[ridx].hess, max_size);
}
}
for (size_t wid = 0; wid < sbuilder.size(); ++wid) {
sbuilder[wid].Finalize(max_size);
}
}
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
// summary array
WXQSketch::SummaryArray summary_array;
// reducer for summary
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
// per node, per feature sketch
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
};
template<typename TStats> template<typename TStats>
class QuantileHistMaker: public HistMaker<TStats> { class QuantileHistMaker: public HistMaker<TStats> {
@ -374,11 +586,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
const BoosterInfo &info, const BoosterInfo &info,
const RegTree &tree) { const RegTree &tree) {
// initialize the data structure // initialize the data structure
int nthread; int nthread = BaseMaker::get_nthread();
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
sketchs.resize(this->qexpand.size() * tree.param.num_feature); sketchs.resize(this->qexpand.size() * tree.param.num_feature);
for (size_t i = 0; i < sketchs.size(); ++i) { for (size_t i = 0; i < sketchs.size(); ++i) {
sketchs[i].Init(info.num_row, this->param.sketch_eps); sketchs[i].Init(info.num_row, this->param.sketch_eps);

View File

@ -638,29 +638,34 @@ class QuantileSketchTemplate {
inqueue.MakeSummary(&temp); inqueue.MakeSummary(&temp);
// cleanup queue // cleanup queue
inqueue.qtail = 0; inqueue.qtail = 0;
for (size_t l = 1; true; ++l) { this->PushTemp();
this->InitLevel(l + 1);
// check if level l is empty
if (level[l].size == 0) {
level[l].SetPrune(temp, limit_size);
break;
} else {
// level 0 is actually temp space
level[0].SetPrune(temp, limit_size);
temp.SetCombine(level[0], level[l]);
if (temp.size > limit_size) {
// try next level
level[l].size = 0;
} else {
// if merged record is still smaller, no need to send to next level
level[l].CopyFrom(temp); break;
}
}
}
} }
} }
inqueue.Push(x, w); inqueue.Push(x, w);
} }
/*! \brief push up temp */
inline void PushTemp(void) {
temp.Reserve(limit_size * 2);
for (size_t l = 1; true; ++l) {
this->InitLevel(l + 1);
// check if level l is empty
if (level[l].size == 0) {
level[l].SetPrune(temp, limit_size);
break;
} else {
// level 0 is actually temp space
level[0].SetPrune(temp, limit_size);
temp.SetCombine(level[0], level[l]);
if (temp.size > limit_size) {
// try next level
level[l].size = 0;
} else {
// if merged record is still smaller, no need to send to next level
level[l].CopyFrom(temp); break;
}
}
}
}
/*! \brief get the summary after finalize */ /*! \brief get the summary after finalize */
inline void GetSummary(SummaryContainer *out) { inline void GetSummary(SummaryContainer *out) {
if (level.size() != 0) { if (level.size() != 0) {