sorted base sketch maker
This commit is contained in:
parent
5e8e9a9b74
commit
5de0a2cdc0
@ -18,6 +18,7 @@ IUpdater* CreateUpdater(const char *name) {
|
|||||||
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
|
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
|
||||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
||||||
if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker<GradStats>();
|
if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker<GradStats>();
|
||||||
|
if (!strcmp(name, "grow_cqmaker")) return new CQHistMaker<GradStats>();
|
||||||
if (!strcmp(name, "grow_chistmaker")) return new ColumnHistMaker<GradStats>();
|
if (!strcmp(name, "grow_chistmaker")) return new ColumnHistMaker<GradStats>();
|
||||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||||
if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >();
|
if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >();
|
||||||
|
|||||||
@ -51,7 +51,9 @@ class HistMaker: public BaseMaker {
|
|||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
const bst_uint ridx) {
|
const bst_uint ridx) {
|
||||||
unsigned i = std::upper_bound(cut, cut + size, fv) - cut;
|
unsigned i = std::upper_bound(cut, cut + size, fv) - cut;
|
||||||
utils::Assert(i < size, "maximum value must be in cut");
|
utils::Assert(size != 0, "try insert into size=0");
|
||||||
|
utils::Assert(i < size,
|
||||||
|
"maximum value must be in cut, fv = %g, cutmax=%g", fv, cut[size-1]);
|
||||||
data[i].Add(gpair, info, ridx);
|
data[i].Add(gpair, info, ridx);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -122,7 +124,7 @@ class HistMaker: public BaseMaker {
|
|||||||
IFMatrix *p_fmat,
|
IFMatrix *p_fmat,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
const RegTree &tree) = 0;
|
const RegTree &tree) = 0;
|
||||||
private:
|
|
||||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||||
IFMatrix *p_fmat,
|
IFMatrix *p_fmat,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
@ -130,6 +132,7 @@ class HistMaker: public BaseMaker {
|
|||||||
this->InitData(gpair, *p_fmat, info.root_index, *p_tree);
|
this->InitData(gpair, *p_fmat, info.root_index, *p_tree);
|
||||||
for (int depth = 0; depth < param.max_depth; ++depth) {
|
for (int depth = 0; depth < param.max_depth; ++depth) {
|
||||||
this->FindSplit(depth, gpair, p_fmat, info, p_tree);
|
this->FindSplit(depth, gpair, p_fmat, info, p_tree);
|
||||||
|
this->ResetPositionCol(this->qexpand, p_fmat, *p_tree);
|
||||||
this->UpdateQueueExpand(*p_tree);
|
this->UpdateQueueExpand(*p_tree);
|
||||||
// if nothing left to be expand, break
|
// if nothing left to be expand, break
|
||||||
if (qexpand.size() == 0) break;
|
if (qexpand.size() == 0) break;
|
||||||
@ -139,6 +142,8 @@ class HistMaker: public BaseMaker {
|
|||||||
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
|
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
inline void CreateHist(const std::vector<bst_gpair> &gpair,
|
inline void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||||
IFMatrix *p_fmat,
|
IFMatrix *p_fmat,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
@ -166,11 +171,7 @@ class HistMaker: public BaseMaker {
|
|||||||
HistSet &hset = wspace.hset[tid];
|
HistSet &hset = wspace.hset[tid];
|
||||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||||
int nid = position[ridx];
|
int nid = position[ridx];
|
||||||
if (!tree[nid].is_leaf()) {
|
|
||||||
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
|
|
||||||
}
|
|
||||||
if (nid >= 0) {
|
if (nid >= 0) {
|
||||||
utils::Assert(tree[nid].is_leaf(), "CreateHist happens in leaf");
|
|
||||||
const int wid = this->node2workindex[nid];
|
const int wid = this->node2workindex[nid];
|
||||||
for (bst_uint i = 0; i < inst.length; ++i) {
|
for (bst_uint i = 0; i < inst.length; ++i) {
|
||||||
utils::Assert(inst[i].index < num_feature, "feature index exceed bound");
|
utils::Assert(inst[i].index < num_feature, "feature index exceed bound");
|
||||||
@ -365,6 +366,217 @@ class ColumnHistMaker: public HistMaker<TStats> {
|
|||||||
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs;
|
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename TStats>
|
||||||
|
class CQHistMaker: public HistMaker<TStats> {
|
||||||
|
protected:
|
||||||
|
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||||
|
IFMatrix *p_fmat,
|
||||||
|
const BoosterInfo &info,
|
||||||
|
const RegTree &tree) {
|
||||||
|
sketchs.resize(this->qexpand.size() * tree.param.num_feature);
|
||||||
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
|
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||||
|
}
|
||||||
|
std::vector< std::vector<SketchEntry> > stemp;
|
||||||
|
stemp.resize(this->get_nthread());
|
||||||
|
|
||||||
|
// start accumulating statistics
|
||||||
|
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
|
||||||
|
iter->BeforeFirst();
|
||||||
|
while (iter->Next()) {
|
||||||
|
const ColBatch &batch = iter->Value();
|
||||||
|
// start enumeration
|
||||||
|
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||||
|
#pragma omp parallel for schedule(dynamic, 1)
|
||||||
|
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||||
|
this->MakeSketch(gpair, batch[i], tree, batch.col_index[i],
|
||||||
|
&stemp[omp_get_thread_num()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// setup maximum size
|
||||||
|
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
|
||||||
|
// synchronize sketch
|
||||||
|
summary_array.Init(sketchs.size(), max_size);
|
||||||
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
|
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||||
|
sketchs[i].GetSummary(&out);
|
||||||
|
summary_array.Set(i, out);
|
||||||
|
}
|
||||||
|
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
||||||
|
sreducer.AllReduce(&summary_array, n4bytes);
|
||||||
|
// now we get the final result of sketch, setup the cut
|
||||||
|
this->wspace.cut.clear();
|
||||||
|
this->wspace.rptr.clear();
|
||||||
|
this->wspace.rptr.push_back(0);
|
||||||
|
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
||||||
|
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
||||||
|
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
|
||||||
|
for (size_t i = 1; i < a.size; ++i) {
|
||||||
|
bst_float cpt = a.data[i].value - rt_eps;
|
||||||
|
if (i == 1 || cpt > this->wspace.cut.back()) {
|
||||||
|
this->wspace.cut.push_back(cpt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// push a value that is greater than anything
|
||||||
|
if (a.size != 0) {
|
||||||
|
bst_float cpt = a.data[a.size - 1].value;
|
||||||
|
// this must be bigger than last value in a scale
|
||||||
|
bst_float last = cpt + fabs(cpt) + rt_eps;
|
||||||
|
this->wspace.cut.push_back(last);
|
||||||
|
}
|
||||||
|
this->wspace.rptr.push_back(this->wspace.cut.size());
|
||||||
|
}
|
||||||
|
// reserve last value for global statistics
|
||||||
|
this->wspace.cut.push_back(0.0f);
|
||||||
|
this->wspace.rptr.push_back(this->wspace.cut.size());
|
||||||
|
}
|
||||||
|
utils::Assert(this->wspace.rptr.size() ==
|
||||||
|
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
|
||||||
|
"cut space inconsistent");
|
||||||
|
}
|
||||||
|
// temporal space to build a sketch
|
||||||
|
struct SketchEntry {
|
||||||
|
/*! \brief total sum of */
|
||||||
|
bst_float sum_total;
|
||||||
|
/*! \brief statistics used in the sketch */
|
||||||
|
bst_float rmin, wmin;
|
||||||
|
/*! \brief last seen feature value */
|
||||||
|
bst_float last_fvalue;
|
||||||
|
/*! \brief current size of sketch */
|
||||||
|
bst_float next_goal;
|
||||||
|
// pointer to the sketch to put things in
|
||||||
|
utils::WXQuantileSketch<bst_float, bst_float> *sketch;
|
||||||
|
// initialize the space
|
||||||
|
inline void Init(unsigned max_size) {
|
||||||
|
next_goal = 0.0f;
|
||||||
|
rmin = wmin = 0.0f;
|
||||||
|
sketch->temp.Reserve(max_size + 1);
|
||||||
|
sketch->temp.size = 0;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief push a new element to sketch
|
||||||
|
* \param fvalue feature value, comes in sorted ascending order
|
||||||
|
* \param w weight
|
||||||
|
* \param max_size
|
||||||
|
*/
|
||||||
|
inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
|
||||||
|
if (w == 0.0f) return;
|
||||||
|
if (wmin == 0.0f) {
|
||||||
|
last_fvalue = fvalue;
|
||||||
|
wmin = w;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (last_fvalue != fvalue) {
|
||||||
|
bst_float rmax = rmin + wmin;
|
||||||
|
if (rmax >= next_goal) {
|
||||||
|
if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
|
||||||
|
// push to sketch
|
||||||
|
sketch->temp.data[sketch->temp.size] =
|
||||||
|
utils::WXQuantileSketch<bst_float, bst_float>::
|
||||||
|
Entry(rmin, rmax, wmin, last_fvalue);
|
||||||
|
utils::Assert(sketch->temp.size < max_size,
|
||||||
|
"invalid maximum size max_size=%u, stemp.size=%lu\n",
|
||||||
|
max_size, sketch->temp.size);
|
||||||
|
++sketch->temp.size;
|
||||||
|
}
|
||||||
|
if (sketch->temp.size == max_size) {
|
||||||
|
next_goal = sum_total * 2.0f + 1e-5f;
|
||||||
|
} else{
|
||||||
|
next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rmin = rmax;
|
||||||
|
wmin = w;
|
||||||
|
last_fvalue = fvalue;
|
||||||
|
} else {
|
||||||
|
wmin += w;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*! \brief push final unfinished value to the sketch */
|
||||||
|
inline void Finalize(unsigned max_size) {
|
||||||
|
bst_float rmax = rmin + wmin;
|
||||||
|
//utils::Assert(fabs(rmax - sum_total) < 1e-4 + sum_total * 1e-5,
|
||||||
|
//"invalid sum value, rmax=%f, sum_total=%lf", rmax, sum_total);
|
||||||
|
if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
|
||||||
|
utils::Assert(sketch->temp.size <= max_size,
|
||||||
|
"Finalize: invalid maximum size, max_size=%u, stemp.size=%lu",
|
||||||
|
sketch->temp.size, max_size );
|
||||||
|
// push to sketch
|
||||||
|
sketch->temp.data[sketch->temp.size] =
|
||||||
|
utils::WXQuantileSketch<bst_float, bst_float>::
|
||||||
|
Entry(rmin, rmax, wmin, last_fvalue);
|
||||||
|
++sketch->temp.size;
|
||||||
|
}
|
||||||
|
sketch->PushTemp();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
inline void MakeSketch(const std::vector<bst_gpair> &gpair,
|
||||||
|
const ColBatch::Inst &c,
|
||||||
|
const RegTree &tree,
|
||||||
|
bst_uint fid,
|
||||||
|
std::vector<SketchEntry> *p_temp) {
|
||||||
|
if (c.length == 0) return;
|
||||||
|
// initialize sbuilder for use
|
||||||
|
std::vector<SketchEntry> &sbuilder = *p_temp;
|
||||||
|
sbuilder.resize(this->qexpand.size());
|
||||||
|
for (size_t i = 0; i < sbuilder.size(); ++i) {
|
||||||
|
sbuilder[i].sum_total = 0.0f;
|
||||||
|
sbuilder[i].sketch = &sketchs[i * tree.param.num_feature + fid];
|
||||||
|
}
|
||||||
|
// second pass, build the sketch
|
||||||
|
for (bst_uint j = 0; j < c.length; ++j) {
|
||||||
|
const bst_uint ridx = c[j].index;
|
||||||
|
const int nid = this->position[ridx];
|
||||||
|
if (nid >= 0) {
|
||||||
|
const int wid = this->node2workindex[nid];
|
||||||
|
sbuilder[wid].sketch->Push(c[j].fvalue, gpair[ridx].hess);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
// first pass, get sum of weight, TODO, optimization to skip first pass
|
||||||
|
for (bst_uint j = 0; j < c.length; ++j) {
|
||||||
|
const bst_uint ridx = c[j].index;
|
||||||
|
const int nid = this->position[ridx];
|
||||||
|
if (nid >= 0) {
|
||||||
|
const int wid = this->node2workindex[nid];
|
||||||
|
sbuilder[wid].sum_total += gpair[ridx].hess;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if only one value, no need to do second pass
|
||||||
|
if (c[0].fvalue == c[c.length-1].fvalue) {
|
||||||
|
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
||||||
|
sbuilder[wid].sketch->Push(c[0].fvalue, sbuilder[wid].sum_total);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// two pass scan
|
||||||
|
unsigned max_size = static_cast<unsigned>(this->param.sketch_ratio / this->param.sketch_eps);
|
||||||
|
for (size_t wid = 0; wid < sbuilder.size(); ++wid) {
|
||||||
|
sbuilder[wid].Init(max_size);
|
||||||
|
}
|
||||||
|
// second pass, build the sketch
|
||||||
|
for (bst_uint j = 0; j < c.length; ++j) {
|
||||||
|
const bst_uint ridx = c[j].index;
|
||||||
|
const int nid = this->position[ridx];
|
||||||
|
if (nid >= 0) {
|
||||||
|
const int wid = this->node2workindex[nid];
|
||||||
|
sbuilder[wid].Push(c[j].fvalue, gpair[ridx].hess, max_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t wid = 0; wid < sbuilder.size(); ++wid) {
|
||||||
|
sbuilder[wid].Finalize(max_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||||
|
// summary array
|
||||||
|
WXQSketch::SummaryArray summary_array;
|
||||||
|
// reducer for summary
|
||||||
|
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
|
||||||
|
// per node, per feature sketch
|
||||||
|
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||||
|
};
|
||||||
|
|
||||||
template<typename TStats>
|
template<typename TStats>
|
||||||
class QuantileHistMaker: public HistMaker<TStats> {
|
class QuantileHistMaker: public HistMaker<TStats> {
|
||||||
@ -374,11 +586,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
const RegTree &tree) {
|
const RegTree &tree) {
|
||||||
// initialize the data structure
|
// initialize the data structure
|
||||||
int nthread;
|
int nthread = BaseMaker::get_nthread();
|
||||||
#pragma omp parallel
|
|
||||||
{
|
|
||||||
nthread = omp_get_num_threads();
|
|
||||||
}
|
|
||||||
sketchs.resize(this->qexpand.size() * tree.param.num_feature);
|
sketchs.resize(this->qexpand.size() * tree.param.num_feature);
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||||
|
|||||||
@ -638,6 +638,14 @@ class QuantileSketchTemplate {
|
|||||||
inqueue.MakeSummary(&temp);
|
inqueue.MakeSummary(&temp);
|
||||||
// cleanup queue
|
// cleanup queue
|
||||||
inqueue.qtail = 0;
|
inqueue.qtail = 0;
|
||||||
|
this->PushTemp();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inqueue.Push(x, w);
|
||||||
|
}
|
||||||
|
/*! \brief push up temp */
|
||||||
|
inline void PushTemp(void) {
|
||||||
|
temp.Reserve(limit_size * 2);
|
||||||
for (size_t l = 1; true; ++l) {
|
for (size_t l = 1; true; ++l) {
|
||||||
this->InitLevel(l + 1);
|
this->InitLevel(l + 1);
|
||||||
// check if level l is empty
|
// check if level l is empty
|
||||||
@ -658,9 +666,6 @@ class QuantileSketchTemplate {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
inqueue.Push(x, w);
|
|
||||||
}
|
|
||||||
/*! \brief get the summary after finalize */
|
/*! \brief get the summary after finalize */
|
||||||
inline void GetSummary(SummaryContainer *out) {
|
inline void GetSummary(SummaryContainer *out) {
|
||||||
if (level.size() != 0) {
|
if (level.size() != 0) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user