Merge branch 'unity' of ssh://github.com/tqchen/xgboost into unity
Conflicts: src/learner/evaluation-inl.hpp
This commit is contained in:
@@ -244,6 +244,7 @@ struct CVGradStats : public GradStats {
|
||||
}
|
||||
/*! \brief calculate gain of the solution */
|
||||
inline double CalcGain(const TrainParam ¶m) const {
|
||||
return param.CalcGain(train[0].sum_grad, train[0].sum_hess);
|
||||
double ret = 0.0;
|
||||
for (unsigned i = 0; i < vsize; ++i) {
|
||||
ret += param.CalcGain(train[i].sum_grad,
|
||||
|
||||
@@ -63,7 +63,7 @@ inline IUpdater<FMatrix>* CreateUpdater(const char *name) {
|
||||
if (!strcmp(name, "refresh")) return new TreeRefresher<FMatrix, GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<FMatrix, GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker2")) return new ColMaker<FMatrix, CVGradStats<2> >();
|
||||
if (!strcmp(name, "grow_colmaker5")) return new ColMaker<FMatrix, CVGradStats<5> >();
|
||||
// if (!strcmp(name, "grow_colmaker5")) return new ColMaker<FMatrix, CVGradStats<5> >();
|
||||
utils::Error("unknown updater:%s", name);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@@ -203,8 +203,8 @@ class ColMaker: public IUpdater<FMatrix> {
|
||||
}
|
||||
// update node statistics
|
||||
snode[nid].stats = stats;
|
||||
snode[nid].root_gain = stats.CalcGain(param);
|
||||
snode[nid].weight = stats.CalcWeight(param);
|
||||
snode[nid].root_gain = static_cast<float>(stats.CalcGain(param));
|
||||
snode[nid].weight = static_cast<float>(stats.CalcWeight(param));
|
||||
}
|
||||
}
|
||||
/*! \brief update queue expand add in new leaves */
|
||||
@@ -251,7 +251,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
||||
if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
|
||||
c.SetSubstract(snode[nid].stats, e.stats);
|
||||
if (c.sum_hess >= param.min_child_weight) {
|
||||
double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
|
||||
bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain);
|
||||
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
|
||||
}
|
||||
}
|
||||
@@ -266,7 +266,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
||||
ThreadEntry &e = temp[nid];
|
||||
c.SetSubstract(snode[nid].stats, e.stats);
|
||||
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
|
||||
const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
|
||||
bst_float loss_chg = static_cast<bst_float>(e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain);
|
||||
const float delta = is_forward_search ? rt_eps : -rt_eps;
|
||||
e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search);
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
||||
for (unsigned i = 0; i < nbatch; ++i) {
|
||||
SparseBatch::Inst inst = batch[i];
|
||||
const int tid = omp_get_thread_num();
|
||||
const size_t ridx = batch.base_rowid + i;
|
||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
RegTree::FVec &feats = fvec_temp[tid];
|
||||
feats.Fill(inst);
|
||||
for (size_t j = 0; j < trees.size(); ++j) {
|
||||
@@ -112,16 +112,16 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
||||
inline void Refresh(const std::vector<TStats> &gstats,
|
||||
int nid, RegTree *p_tree) {
|
||||
RegTree &tree = *p_tree;
|
||||
tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
|
||||
tree.stat(nid).base_weight = static_cast<float>(gstats[nid].CalcWeight(param));
|
||||
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
|
||||
gstats[nid].SetLeafVec(param, tree.leafvec(nid));
|
||||
if (tree[nid].is_leaf()) {
|
||||
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
||||
} else {
|
||||
tree.stat(nid).loss_chg =
|
||||
tree.stat(nid).loss_chg = static_cast<float>(
|
||||
gstats[tree[nid].cleft()].CalcGain(param) +
|
||||
gstats[tree[nid].cright()].CalcGain(param) -
|
||||
gstats[nid].CalcGain(param);
|
||||
gstats[nid].CalcGain(param));
|
||||
this->Refresh(gstats, tree[nid].cleft(), p_tree);
|
||||
this->Refresh(gstats, tree[nid].cright(), p_tree);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user