finish find split, next to do quantile sketch
This commit is contained in:
parent
a7bc769971
commit
dcd0dd5e26
@ -7,6 +7,7 @@
|
||||
*/
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "../sync/sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -75,7 +76,7 @@ class HistMaker: public IUpdater {
|
||||
/*! \brief actual unit pointer */
|
||||
std::vector<unsigned> rptr;
|
||||
/*! \brief cut field */
|
||||
std::vector<unsigned> cut;
|
||||
std::vector<bst_float> cut;
|
||||
// per thread histset
|
||||
std::vector<HistSet> hset;
|
||||
// initialize the hist set
|
||||
@ -106,7 +107,7 @@ class HistMaker: public IUpdater {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
for (size_t tid = 1; tid < hset.size(); ++tid) {
|
||||
hset[0][i].Add(hset[tid][i]);
|
||||
hset[0].data[i].Add(hset[tid].data[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -125,12 +126,34 @@ class HistMaker: public IUpdater {
|
||||
ThreadWSpace wspace;
|
||||
// position of each data
|
||||
std::vector<int> position;
|
||||
/*! \brief queue of nodes to be expanded */
|
||||
std::vector<int> qexpand;
|
||||
/*! \brief map active node to its working index offset in qexpand */
|
||||
std::vector<int> node2workindex;
|
||||
// reducer for histogram
|
||||
sync::Reducer<TStats> histred;
|
||||
private:
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
RegTree *p_tree) {
|
||||
//this->InitData(gpair, *p_fmat, info.root_index, *p_tree);
|
||||
//this->InitNewNode(qexpand_, gpair, *p_fmat, info, *p_tree);
|
||||
for (int depth = 0; depth < param.max_depth; ++depth) {
|
||||
this->FindSplit(depth, gpair, p_fmat, info, p_tree);
|
||||
//this->ResetPosition(qexpand_, p_fmat, *p_tree);
|
||||
//this->UpdateQueueExpand(*p_tree, &qexpand_);
|
||||
//this->InitNewNode(qexpand_, gpair, *p_fmat, info, *p_tree);
|
||||
// if nothing left to be expand, break
|
||||
if (qexpand.size() == 0) break;
|
||||
}
|
||||
}
|
||||
// create histogram for a setup histset
|
||||
inline void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
unsigned num_feature) {
|
||||
const RegTree &tree) {
|
||||
bst_uint num_feature = tree.param.num_feature;
|
||||
// initialize work space
|
||||
wspace.Init(param);
|
||||
// start accumulating statistics
|
||||
@ -147,20 +170,100 @@ class HistMaker: public IUpdater {
|
||||
const int tid = omp_get_thread_num();
|
||||
HistSet &hset = wspace.hset[tid];
|
||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
int nid = position[ridx];
|
||||
const int nid = position[ridx];
|
||||
if (nid >= 0) {
|
||||
utils::Assert(tree[nid].is_leaf(), "CreateHist happens in leaf");
|
||||
const int wid = node2workindex[nid];
|
||||
for (bst_uint i = 0; i < inst.length; ++i) {
|
||||
utils::Assert(inst[i].index < num_feature, "feature index exceed bound");
|
||||
hset[inst[i].index + nid * num_feature]
|
||||
// feature histogram
|
||||
hset[inst[i].index + wid * (num_feature+1)]
|
||||
.Add(inst[i].fvalue, gpair, info, ridx);
|
||||
}
|
||||
// node histogram, use num_feature to borrow space
|
||||
hset[num_feature + wid * (num_feature + 1)]
|
||||
.data[0].Add(gpair, info, ridx);
|
||||
}
|
||||
}
|
||||
}
|
||||
// accumulating statistics together
|
||||
wspace.Aggregate();
|
||||
// get the split solution
|
||||
}
|
||||
// sync the histogram
|
||||
histred.AllReduce(BeginPtr(wspace.hset[0].data), wspace.hset[0].data.size());
|
||||
}
|
||||
inline void EnumerateSplit(const HistUnit &hist,
|
||||
const TStats &node_sum,
|
||||
bst_uint fid,
|
||||
SplitEntry *best) {
|
||||
double root_gain = node_sum.CalcGain(param);
|
||||
TStats s(param), c(param);
|
||||
for (bst_uint i = 0; i < hist.size; ++i) {
|
||||
s.Add(hist.data[i]);
|
||||
if (s.sum_hess >= param.min_child_weight) {
|
||||
c.SetSubstract(node_sum, s);
|
||||
if (c.sum_hess >= param.min_child_weight) {
|
||||
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
|
||||
best->Update(loss_chg, fid, hist.cut[i], false);
|
||||
}
|
||||
}
|
||||
}
|
||||
s.Clear();
|
||||
for (bst_uint i = hist.size - 1; i != 0; --i) {
|
||||
s.Add(hist.data[i]);
|
||||
if (s.sum_hess >= param.min_child_weight) {
|
||||
c.SetSubstract(node_sum, s);
|
||||
if (c.sum_hess >= param.min_child_weight) {
|
||||
double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain;
|
||||
best->Update(loss_chg, fid, hist.cut[i-1], true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void FindSplit(int depth,
|
||||
const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
RegTree *p_tree) {
|
||||
const bst_uint num_feature = p_tree->param.num_feature;
|
||||
// create histogram
|
||||
this->CreateHist(gpair, p_fmat, info, *p_tree);
|
||||
// get the best split condition for each node
|
||||
std::vector<SplitEntry> sol(qexpand.size());
|
||||
bst_omp_uint nexpand = static_cast<bst_omp_uint>(qexpand.size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint wid = 0; wid < nexpand; ++ wid) {
|
||||
const int nid = qexpand[wid];
|
||||
utils::Assert(node2workindex[nid] == static_cast<int>(wid), "node2workindex inconsistent");
|
||||
SplitEntry &best = sol[wid];
|
||||
TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0];
|
||||
for (bst_uint fid = 0; fid < num_feature; ++ fid) {
|
||||
EnumerateSplit(wspace.hset[0][fid + wid * (num_feature+1)],
|
||||
node_sum, fid, &best);
|
||||
}
|
||||
}
|
||||
// get the best result, we can synchronize the solution
|
||||
for (bst_omp_uint wid = 0; wid < nexpand; ++ wid) {
|
||||
const int nid = qexpand[wid];
|
||||
const SplitEntry &best = sol[wid];
|
||||
const TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0];
|
||||
bst_float weight = node_sum.CalcWeight(param);
|
||||
// set up the values
|
||||
p_tree->stat(nid).loss_chg = best.loss_chg;
|
||||
p_tree->stat(nid).base_weight = weight;
|
||||
p_tree->stat(nid).sum_hess = static_cast<float>(node_sum.sum_hess);
|
||||
node_sum.SetLeafVec(param, p_tree->leafvec(nid));
|
||||
// now we know the solution in snode[nid], set split
|
||||
if (best.loss_chg > rt_eps) {
|
||||
p_tree->AddChilds(nid);
|
||||
(*p_tree)[nid].set_split(best.split_index(), best.split_value, best.default_left());
|
||||
// mark right child as 0, to indicate fresh leaf
|
||||
(*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0);
|
||||
(*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0);
|
||||
} else {
|
||||
(*p_tree)[nid].set_leaf(weight * param.learning_rate);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user