fix compile, need final leaf node?
This commit is contained in:
parent
c86b83ea04
commit
daa28f238e
@ -166,7 +166,8 @@ class HistMaker: public IUpdater {
|
|||||||
// initialize temp data structure
|
// initialize temp data structure
|
||||||
inline void InitData(const std::vector<bst_gpair> &gpair,
|
inline void InitData(const std::vector<bst_gpair> &gpair,
|
||||||
const IFMatrix &fmat,
|
const IFMatrix &fmat,
|
||||||
const std::vector<unsigned> &root_index, const RegTree &tree) {
|
const std::vector<unsigned> &root_index,
|
||||||
|
const RegTree &tree) {
|
||||||
utils::Assert(tree.param.num_nodes == tree.param.num_roots,
|
utils::Assert(tree.param.num_nodes == tree.param.num_roots,
|
||||||
"HistMaker: can only grow new tree");
|
"HistMaker: can only grow new tree");
|
||||||
{// setup position
|
{// setup position
|
||||||
@ -271,6 +272,7 @@ class HistMaker: public IUpdater {
|
|||||||
const TStats &node_sum,
|
const TStats &node_sum,
|
||||||
bst_uint fid,
|
bst_uint fid,
|
||||||
SplitEntry *best) {
|
SplitEntry *best) {
|
||||||
|
if (hist.size == 0) return;
|
||||||
double root_gain = node_sum.CalcGain(param);
|
double root_gain = node_sum.CalcGain(param);
|
||||||
TStats s(param), c(param);
|
TStats s(param), c(param);
|
||||||
for (bst_uint i = 0; i < hist.size; ++i) {
|
for (bst_uint i = 0; i < hist.size; ++i) {
|
||||||
@ -319,7 +321,7 @@ class HistMaker: public IUpdater {
|
|||||||
EnumerateSplit(wspace.hset[0][fid + wid * (num_feature+1)],
|
EnumerateSplit(wspace.hset[0][fid + wid * (num_feature+1)],
|
||||||
node_sum, fid, &best);
|
node_sum, fid, &best);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// get the best result, we can synchronize the solution
|
// get the best result, we can synchronize the solution
|
||||||
for (bst_omp_uint wid = 0; wid < nexpand; ++ wid) {
|
for (bst_omp_uint wid = 0; wid < nexpand; ++ wid) {
|
||||||
const int nid = qexpand[wid];
|
const int nid = qexpand[wid];
|
||||||
@ -334,7 +336,8 @@ class HistMaker: public IUpdater {
|
|||||||
// now we know the solution in snode[nid], set split
|
// now we know the solution in snode[nid], set split
|
||||||
if (best.loss_chg > rt_eps) {
|
if (best.loss_chg > rt_eps) {
|
||||||
p_tree->AddChilds(nid);
|
p_tree->AddChilds(nid);
|
||||||
(*p_tree)[nid].set_split(best.split_index(), best.split_value, best.default_left());
|
(*p_tree)[nid].set_split(best.split_index(),
|
||||||
|
best.split_value, best.default_left());
|
||||||
// mark right child as 0, to indicate fresh leaf
|
// mark right child as 0, to indicate fresh leaf
|
||||||
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
|
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
|
||||||
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
|
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
|
||||||
@ -379,10 +382,12 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||||
int nid = this->position[ridx];
|
int nid = this->position[ridx];
|
||||||
if (nid >= 0) {
|
if (nid >= 0) {
|
||||||
if (tree[nid].is_leaf()) {
|
if (!tree[nid].is_leaf()) {
|
||||||
this->position[ridx] = ~nid;
|
|
||||||
} else {
|
|
||||||
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
|
this->position[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
|
||||||
|
}
|
||||||
|
if (this->node2workindex[nid] < 0) {
|
||||||
|
this->position[ridx] = ~nid;
|
||||||
|
} else{
|
||||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||||
builder.AddBudget(inst[j].index, omp_get_thread_num());
|
builder.AddBudget(inst[j].index, omp_get_thread_num());
|
||||||
}
|
}
|
||||||
@ -404,7 +409,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// start putting things into sketch
|
// start putting things into sketch
|
||||||
const bst_omp_uint nfeat = tree.param.num_feature;
|
const bst_omp_uint nfeat = col_ptr.size() - 1;
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
#pragma omp parallel for schedule(dynamic, 1)
|
||||||
for (bst_omp_uint k = 0; k < nfeat; ++k) {
|
for (bst_omp_uint k = 0; k < nfeat; ++k) {
|
||||||
for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) {
|
for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) {
|
||||||
@ -418,15 +423,23 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
|
size_t max_size = static_cast<size_t>(this->param.sketch_ratio / this->param.sketch_eps);
|
||||||
// synchronize sketch
|
// synchronize sketch
|
||||||
summary_array.Init(sketchs.size(), max_size);
|
summary_array.Init(sketchs.size(), max_size);
|
||||||
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
|
utils::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||||
|
sketchs[i].GetSummary(&out);
|
||||||
|
summary_array.Set(i, out);
|
||||||
|
}
|
||||||
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
||||||
sreducer.AllReduce(&summary_array, n4bytes);
|
sreducer.AllReduce(&summary_array, n4bytes);
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
this->wspace.cut.clear();
|
||||||
|
this->wspace.rptr.clear();
|
||||||
|
this->wspace.rptr.push_back(0);
|
||||||
|
for (size_t wid = 0; wid < this->qexpand.size(); ++wid) {
|
||||||
for (size_t fid = 0; fid < tree.param.num_feature; ++fid) {
|
for (size_t fid = 0; fid < tree.param.num_feature; ++fid) {
|
||||||
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
|
const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid];
|
||||||
for (size_t i = 0; i < a.size; ++i) {
|
for (size_t i = 0; i < a.size; ++i) {
|
||||||
bst_float cpt = a.data[i].value + rt_eps;
|
bst_float cpt = a.data[i].value + rt_eps;
|
||||||
if (i == 0 || cpt > this->wspace.cut.back()){
|
if (i == 0 || cpt > this->wspace.cut.back()) {
|
||||||
this->wspace.cut.push_back(cpt);
|
this->wspace.cut.push_back(cpt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -437,7 +450,8 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
this->wspace.rptr.push_back(this->wspace.cut.size());
|
this->wspace.rptr.push_back(this->wspace.cut.size());
|
||||||
}
|
}
|
||||||
utils::Assert(this->wspace.rptr.size() ==
|
utils::Assert(this->wspace.rptr.size() ==
|
||||||
(tree.param.num_feature + 1) * this->qexpand.size(), "cut space inconsistent");
|
(tree.param.num_feature + 1) * this->qexpand.size() + 1,
|
||||||
|
"cut space inconsistent");
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -258,7 +258,7 @@ struct WXQSummary : public WQSummary<DType, RType> {
|
|||||||
return e.rmin_next() > e.rmax_prev() + chunk;
|
return e.rmin_next() > e.rmax_prev() + chunk;
|
||||||
}
|
}
|
||||||
// set prune
|
// set prune
|
||||||
inline void SetPrune(const WXQSummary &src, RType maxsize) {
|
inline void SetPrune(const WQSummary<DType, RType> &src, RType maxsize) {
|
||||||
if (src.size <= maxsize) {
|
if (src.size <= maxsize) {
|
||||||
this->CopyFrom(src); return;
|
this->CopyFrom(src); return;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user