check in two bad ones, start to think of column-distributed cut row
This commit is contained in:
parent
5061d55725
commit
8ed585a7a2
@ -15,7 +15,8 @@ IUpdater* CreateUpdater(const char *name) {
|
||||
if (!strcmp(name, "prune")) return new TreePruner();
|
||||
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_histmaker")) return new QuantileHistMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_chistmaker")) return new ColumnHistMaker<GradStats>();
|
||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >();
|
||||
if (!strcmp(name, "grow_colmaker3")) return new ColMaker< CVGradStats<3> >();
|
||||
|
||||
@ -251,7 +251,10 @@ class HistMaker: public IUpdater {
|
||||
const int tid = omp_get_thread_num();
|
||||
HistSet &hset = wspace.hset[tid];
|
||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
const int nid = position[ridx];
|
||||
int nid = position[ridx];
|
||||
if (!tree[nid].is_leaf()) {
|
||||
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
|
||||
}
|
||||
if (nid >= 0) {
|
||||
utils::Assert(tree[nid].is_leaf(), "CreateHist happens in leaf");
|
||||
const int wid = this->node2workindex[nid];
|
||||
@ -367,7 +370,88 @@ class HistMaker: public IUpdater {
|
||||
}
|
||||
};
|
||||
|
||||
// hist maker that propose using quantile sketch
|
||||
template<typename TStats>
|
||||
class ColumnHistMaker: public HistMaker<TStats> {
|
||||
public:
|
||||
virtual void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const RegTree &tree) {
|
||||
sketchs.resize(tree.param.num_feature);
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||
}
|
||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator();
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const bst_uint fid = batch.col_index[i];
|
||||
const ColBatch::Inst &col = batch[i];
|
||||
unsigned nstep = col.length * (this->param.sketch_eps / this->param.sketch_ratio);
|
||||
if (nstep == 0) nstep = 1;
|
||||
for (unsigned i = 0; i < col.length; i += nstep) {
|
||||
sketchs[fid].Push(col[i].fvalue);
|
||||
}
|
||||
if (col.length != 0 && col.length - 1 % nstep != 0) {
|
||||
sketchs[fid].Push(col[col.length-1].fvalue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
|
||||
// synchronize sketch
|
||||
summary_array.Init(sketchs.size(), max_size);
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array.Set(i, out);
|
||||
}
|
||||
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
||||
sreducer.AllReduce(&summary_array, n4bytes);
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace.cut.clear();
|
||||
this->wspace.rptr.clear();
|
||||
this->wspace.rptr.push_back(0);
|
||||
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
||||
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
||||
const WXQSketch::Summary a = summary_array[fid];
|
||||
for (size_t i = 1; i < a.size; ++i) {
|
||||
bst_float cpt = a.data[i].value - rt_eps;
|
||||
if (i == 1 || cpt > this->wspace.cut.back()) {
|
||||
this->wspace.cut.push_back(cpt);
|
||||
}
|
||||
}
|
||||
// push a value that is greater than anything
|
||||
if (a.size != 0) {
|
||||
bst_float cpt = a.data[a.size - 1].value;
|
||||
// this must be bigger than last value in a scale
|
||||
bst_float last = cpt + fabs(cpt) + rt_eps;
|
||||
this->wspace.cut.push_back(last);
|
||||
}
|
||||
this->wspace.rptr.push_back(this->wspace.cut.size());
|
||||
}
|
||||
// reserve last value for global statistics
|
||||
this->wspace.cut.push_back(0.0f);
|
||||
this->wspace.rptr.push_back(this->wspace.cut.size());
|
||||
}
|
||||
utils::Assert(this->wspace.rptr.size() ==
|
||||
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
|
||||
"cut space inconsistent");
|
||||
}
|
||||
|
||||
private:
|
||||
typedef utils::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
// summary array
|
||||
WXQSketch::SummaryArray summary_array;
|
||||
// reducer for summary
|
||||
sync::ComplexReducer<WXQSketch::SummaryArray> sreducer;
|
||||
// per feature sketch
|
||||
std::vector< utils::WQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
|
||||
|
||||
template<typename TStats>
|
||||
class QuantileHistMaker: public HistMaker<TStats> {
|
||||
protected:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user