commit 02c2278f96
Author: tqchen
Date:   2014-11-15 21:18:15 -08:00
Parent: daa28f238e


@@ -162,6 +162,10 @@ class HistMaker: public IUpdater {
       // if nothing left to be expand, break
       if (qexpand.size() == 0) break;
     }
+    for (size_t i = 0; i < qexpand.size(); ++i) {
+      const int nid = qexpand[i];
+      (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
+    }
   }
   // initialize temp data structure
   inline void InitData(const std::vector<bst_gpair> &gpair,
@ -271,7 +275,8 @@ class HistMaker: public IUpdater {
inline void EnumerateSplit(const HistUnit &hist, inline void EnumerateSplit(const HistUnit &hist,
const TStats &node_sum, const TStats &node_sum,
bst_uint fid, bst_uint fid,
SplitEntry *best) { SplitEntry *best,
TStats *left_sum) {
if (hist.size == 0) return; if (hist.size == 0) return;
double root_gain = node_sum.CalcGain(param); double root_gain = node_sum.CalcGain(param);
TStats s(param), c(param); TStats s(param), c(param);
@@ -281,7 +286,9 @@ class HistMaker: public IUpdater {
       c.SetSubstract(node_sum, s);
       if (c.sum_hess >= param.min_child_weight) {
         double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
-        best->Update(loss_chg, fid, hist.cut[i], false);
+        if (best->Update(loss_chg, fid, hist.cut[i], false)) {
+          *left_sum = s;
+        }
       }
     }
   }
@@ -292,7 +299,9 @@ class HistMaker: public IUpdater {
       c.SetSubstract(node_sum, s);
       if (c.sum_hess >= param.min_child_weight) {
         double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
-        best->Update(loss_chg, fid, hist.cut[i-1], true);
+        if (best->Update(loss_chg, fid, hist.cut[i-1], true)) {
+          *left_sum = c;
+        }
       }
     }
   }
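The sketch below illustrates, in isolation, the pattern the two hunks above rely on: Update returns whether the candidate became the new best split, so the caller can snapshot the left-child statistics of the winning split as it goes. The Stats and Best types and the toy gain formula are simplified stand-ins for illustration, not the actual TStats/SplitEntry API.

// Simplified stand-ins for TStats / SplitEntry (illustrative only).
#include <cstdio>

struct Stats {
  double sum_grad = 0.0, sum_hess = 0.0;
  void Add(double g, double h) { sum_grad += g; sum_hess += h; }
  void SetSubstract(const Stats &a, const Stats &b) {
    sum_grad = a.sum_grad - b.sum_grad;
    sum_hess = a.sum_hess - b.sum_hess;
  }
  // toy gain: G^2 / (H + 1)
  double CalcGain() const { return sum_grad * sum_grad / (sum_hess + 1.0); }
};

struct Best {
  double loss_chg = 0.0, split_value = 0.0;
  // returns true only when the candidate becomes the new best
  bool Update(double chg, double value) {
    if (chg > loss_chg) { loss_chg = chg; split_value = value; return true; }
    return false;
  }
};

int main() {
  const double cut[]  = {0.5, 1.5, 2.5};
  const double grad[] = {1.0, -2.0, 0.5}, hess[] = {1.0, 1.0, 1.0};
  Stats node_sum;
  for (int i = 0; i < 3; ++i) node_sum.Add(grad[i], hess[i]);
  const double root_gain = node_sum.CalcGain();

  Best best;
  Stats s, c, left_sum;  // left_sum plays the role of the diff's *left_sum output
  for (int i = 0; i < 3; ++i) {
    s.Add(grad[i], hess[i]);      // cumulative left-side statistics
    c.SetSubstract(node_sum, s);  // right side = parent - left
    const double loss_chg = s.CalcGain() + c.CalcGain() - root_gain;
    if (best.Update(loss_chg, cut[i])) {
      left_sum = s;               // snapshot left statistics of the winning split
    }
  }
  std::printf("best cut=%.1f, left sum_hess=%.1f\n", best.split_value, left_sum.sum_hess);
  return 0;
}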
@@ -309,6 +318,7 @@ class HistMaker: public IUpdater {
     this->CreateHist(gpair, p_fmat, info, *p_tree);
     // get the best split condition for each node
     std::vector<SplitEntry> sol(qexpand.size());
+    std::vector<TStats> left_sum(qexpand.size());
     bst_omp_uint nexpand = static_cast<bst_omp_uint>(qexpand.size());
     #pragma omp parallel for schedule(dynamic, 1)
     for (bst_omp_uint wid = 0; wid < nexpand; ++ wid) {
@@ -319,7 +329,7 @@ class HistMaker: public IUpdater {
       TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0];
       for (bst_uint fid = 0; fid < num_feature; ++ fid) {
         EnumerateSplit(wspace.hset[0][fid + wid * (num_feature+1)],
-                       node_sum, fid, &best);
+                       node_sum, fid, &best, &left_sum[wid]);
       }
     }
     // get the best result, we can synchronize the solution
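A minimal sketch of the per-node result slots used in the hunk above: sol[wid] and left_sum[wid] are preallocated before the OpenMP loop, so each iteration writes only to its own slot and no synchronization is needed. The Result type and the stand-in gain values below are assumptions for illustration, not the real SplitEntry.

// Illustrative only: preallocated per-iteration result slots for an OpenMP loop.
#include <cstdio>
#include <vector>

struct Result { double best = 0.0; int arg = -1; };

int main() {
  const int nexpand = 8;
  std::vector<Result> sol(nexpand);   // one slot per expanded node, allocated up front
  #pragma omp parallel for schedule(dynamic, 1)
  for (int wid = 0; wid < nexpand; ++wid) {
    Result &best = sol[wid];          // this iteration writes only to its own slot
    for (int fid = 0; fid < 4; ++fid) {
      double score = (wid + 1) * 0.1 * (fid + 1);  // stand-in for a split gain
      if (score > best.best) { best.best = score; best.arg = fid; }
    }
  }
  // a sequential pass reads the per-slot results afterwards
  for (int wid = 0; wid < nexpand; ++wid)
    std::printf("wid=%d best=%.2f fid=%d\n", wid, sol[wid].best, sol[wid].arg);
  return 0;
}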
@@ -327,12 +337,9 @@ class HistMaker: public IUpdater {
       const int nid = qexpand[wid];
       const SplitEntry &best = sol[wid];
       const TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0];
-      bst_float weight = node_sum.CalcWeight(param);
+      this->SetStats(p_tree, nid, node_sum);
       // set up the values
       p_tree->stat(nid).loss_chg = best.loss_chg;
-      p_tree->stat(nid).base_weight = weight;
-      p_tree->stat(nid).sum_hess = static_cast<float>(node_sum.sum_hess);
-      node_sum.SetLeafVec(param, p_tree->leafvec(nid));
       // now we know the solution in snode[nid], set split
       if (best.loss_chg > rt_eps) {
         p_tree->AddChilds(nid);
@@ -341,11 +348,22 @@ class HistMaker: public IUpdater {
         // mark right child as 0, to indicate fresh leaf
         (*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
         (*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
+        // right side sum
+        TStats right_sum;
+        right_sum.SetSubstract(node_sum, left_sum[wid]);
+        this->SetStats(p_tree, (*p_tree)[nid].cleft(), left_sum[wid]);
+        this->SetStats(p_tree, (*p_tree)[nid].cright(), right_sum);
       } else {
-        (*p_tree)[nid].set_leaf(weight * param.learning_rate);
+        (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate);
       }
     }
   }
+  inline void SetStats(RegTree *p_tree, int nid, const TStats &node_sum) {
+    p_tree->stat(nid).base_weight = node_sum.CalcWeight(param);
+    p_tree->stat(nid).sum_hess = static_cast<float>(node_sum.sum_hess);
+    node_sum.SetLeafVec(param, p_tree->leafvec(nid));
+  }
 };
 // hist maker that propose using quantile sketch
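Taken together, the changes keep per-node sums so that a child's statistics never have to be recomputed: the right child is obtained by subtracting the recorded left sum from the parent sum, and SetStats fills base_weight and sum_hess for both fresh leaves, which is what lets the loop added in the first hunk finalize any still-unexpanded node directly from stat(nid).base_weight. Below is a minimal sketch of that bookkeeping, with simplified Stats/NodeStat types and illustrative lambda/learning_rate values, not the real RegTree/TStats API.

// Illustrative only: right-child statistics by subtraction, then per-node stats.
#include <cstdio>
#include <vector>

struct Stats {
  double sum_grad = 0.0, sum_hess = 0.0;
  void SetSubstract(const Stats &a, const Stats &b) {
    sum_grad = a.sum_grad - b.sum_grad;
    sum_hess = a.sum_hess - b.sum_hess;
  }
  // usual second-order leaf weight: -G / (H + lambda)
  double CalcWeight(double lambda) const { return -sum_grad / (sum_hess + lambda); }
};

struct NodeStat { double base_weight = 0.0, sum_hess = 0.0; };

// analogue of the SetStats helper: fill a node's stats from an accumulated sum
void SetStats(std::vector<NodeStat> *stats, int nid, const Stats &sum, double lambda) {
  (*stats)[nid].base_weight = sum.CalcWeight(lambda);
  (*stats)[nid].sum_hess = sum.sum_hess;
}

int main() {
  const double lambda = 1.0, learning_rate = 0.3;
  std::vector<NodeStat> stats(3);  // node 0 = parent, 1 = left child, 2 = right child

  Stats node_sum;                  // parent sum, known from the histogram pass
  node_sum.sum_grad = -0.5; node_sum.sum_hess = 3.0;
  Stats left_sum;                  // recorded while enumerating splits
  left_sum.sum_grad = 1.0; left_sum.sum_hess = 1.0;

  Stats right_sum;
  right_sum.SetSubstract(node_sum, left_sum);  // right = parent - left, no extra pass

  SetStats(&stats, 0, node_sum, lambda);
  SetStats(&stats, 1, left_sum, lambda);
  SetStats(&stats, 2, right_sum, lambda);

  // a later pass can turn any still-unexpanded node into a leaf directly
  std::printf("left leaf value = %.3f\n", stats[1].base_weight * learning_rate);
  return 0;
}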