This commit is contained in:
tqchen 2014-11-15 21:18:15 -08:00
parent daa28f238e
commit 02c2278f96

View File

@ -161,7 +161,11 @@ class HistMaker: public IUpdater {
this->UpdateNode2WorkIndex(*p_tree);
// if nothing left to be expand, break
if (qexpand.size() == 0) break;
}
}
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
}
}
// initialize temp data structure
inline void InitData(const std::vector<bst_gpair> &gpair,
@ -271,7 +275,8 @@ class HistMaker: public IUpdater {
inline void EnumerateSplit(const HistUnit &hist,
const TStats &node_sum,
bst_uint fid,
SplitEntry *best) {
SplitEntry *best,
TStats *left_sum) {
if (hist.size == 0) return;
double root_gain = node_sum.CalcGain(param);
TStats s(param), c(param);
@ -281,7 +286,9 @@ class HistMaker: public IUpdater {
c.SetSubstract(node_sum, s);
if (c.sum_hess >= param.min_child_weight) {
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
best->Update(loss_chg, fid, hist.cut[i], false);
if (best->Update(loss_chg, fid, hist.cut[i], false)) {
*left_sum = s;
}
}
}
}
@ -292,7 +299,9 @@ class HistMaker: public IUpdater {
c.SetSubstract(node_sum, s);
if (c.sum_hess >= param.min_child_weight) {
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
best->Update(loss_chg, fid, hist.cut[i-1], true);
if (best->Update(loss_chg, fid, hist.cut[i-1], true)) {
*left_sum = c;
}
}
}
}
@ -309,17 +318,18 @@ class HistMaker: public IUpdater {
this->CreateHist(gpair, p_fmat, info, *p_tree);
// get the best split condition for each node
std::vector<SplitEntry> sol(qexpand.size());
std::vector<TStats> left_sum(qexpand.size());
bst_omp_uint nexpand = static_cast<bst_omp_uint>(qexpand.size());
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint wid = 0; wid < nexpand; ++ wid) {
const int nid = qexpand[wid];
utils::Assert(node2workindex[nid] == static_cast<int>(wid),
"node2workindex inconsistent");
SplitEntry &best = sol[wid];
SplitEntry &best = sol[wid];
TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0];
for (bst_uint fid = 0; fid < num_feature; ++ fid) {
EnumerateSplit(wspace.hset[0][fid + wid * (num_feature+1)],
node_sum, fid, &best);
node_sum, fid, &best, &left_sum[wid]);
}
}
// get the best result, we can synchronize the solution
@ -327,25 +337,33 @@ class HistMaker: public IUpdater {
const int nid = qexpand[wid];
const SplitEntry &best = sol[wid];
const TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0];
bst_float weight = node_sum.CalcWeight(param);
this->SetStats(p_tree, nid, node_sum);
// set up the values
p_tree->stat(nid).loss_chg = best.loss_chg;
p_tree->stat(nid).base_weight = weight;
p_tree->stat(nid).sum_hess = static_cast<float>(node_sum.sum_hess);
node_sum.SetLeafVec(param, p_tree->leafvec(nid));
// now we know the solution in snode[nid], set split
if (best.loss_chg > rt_eps) {
p_tree->AddChilds(nid);
(*p_tree)[nid].set_split(best.split_index(),
best.split_value, best.default_left());
// mark right child as 0, to indicate fresh leaf
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
// right side sum
TStats right_sum;
right_sum.SetSubstract(node_sum, left_sum[wid]);
this->SetStats(p_tree, (*p_tree)[nid].cleft(), left_sum[wid]);
this->SetStats(p_tree, (*p_tree)[nid].cright(), right_sum);
} else {
(*p_tree)[nid].set_leaf(weight * param.learning_rate);
(*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
}
}
}
inline void SetStats(RegTree *p_tree, int nid, const TStats &node_sum) {
p_tree->stat(nid).base_weight = node_sum.CalcWeight(param);
p_tree->stat(nid).sum_hess = static_cast<float>(node_sum.sum_hess);
node_sum.SetLeafVec(param, p_tree->leafvec(nid));
}
};
// hist maker that propose using quantile sketch