updated base
This commit is contained in:
parent
8874234e5e
commit
5e8e9a9b74
@ -27,8 +27,8 @@ class BaseMaker: public IUpdater {
|
|||||||
protected:
|
protected:
|
||||||
// ------static helper functions ------
|
// ------static helper functions ------
|
||||||
// helper function to get to next level of the tree
|
// helper function to get to next level of the tree
|
||||||
// must work on non-leaf node
|
/*! \brief this is helper function for row based data*/
|
||||||
inline static int NextLevel(const SparseBatch::Inst &inst, const RegTree &tree, int nid) {
|
inline static int NextLevel(const RowBatch::Inst &inst, const RegTree &tree, int nid) {
|
||||||
const RegTree::Node &n = tree[nid];
|
const RegTree::Node &n = tree[nid];
|
||||||
bst_uint findex = n.split_index();
|
bst_uint findex = n.split_index();
|
||||||
for (unsigned i = 0; i < inst.length; ++i) {
|
for (unsigned i = 0; i < inst.length; ++i) {
|
||||||
@ -52,19 +52,6 @@ class BaseMaker: public IUpdater {
|
|||||||
return nthread;
|
return nthread;
|
||||||
}
|
}
|
||||||
// ------class member helpers---------
|
// ------class member helpers---------
|
||||||
// return decoded position
|
|
||||||
inline int DecodePosition(bst_uint ridx) const{
|
|
||||||
const int pid = position[ridx];
|
|
||||||
return pid < 0 ? ~pid : pid;
|
|
||||||
}
|
|
||||||
// encode the encoded position value for ridx
|
|
||||||
inline void SetEncodePosition(bst_uint ridx, int nid) {
|
|
||||||
if (position[ridx] < 0) {
|
|
||||||
position[ridx] = ~nid;
|
|
||||||
} else {
|
|
||||||
position[ridx] = nid;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*! \brief initialize temp data structure */
|
/*! \brief initialize temp data structure */
|
||||||
inline void InitData(const std::vector<bst_gpair> &gpair,
|
inline void InitData(const std::vector<bst_gpair> &gpair,
|
||||||
const IFMatrix &fmat,
|
const IFMatrix &fmat,
|
||||||
@ -117,6 +104,99 @@ class BaseMaker: public IUpdater {
|
|||||||
qexpand = newnodes;
|
qexpand = newnodes;
|
||||||
this->UpdateNode2WorkIndex(tree);
|
this->UpdateNode2WorkIndex(tree);
|
||||||
}
|
}
|
||||||
|
// return decoded position
|
||||||
|
inline int DecodePosition(bst_uint ridx) const{
|
||||||
|
const int pid = position[ridx];
|
||||||
|
return pid < 0 ? ~pid : pid;
|
||||||
|
}
|
||||||
|
// encode the encoded position value for ridx
|
||||||
|
inline void SetEncodePosition(bst_uint ridx, int nid) {
|
||||||
|
if (position[ridx] < 0) {
|
||||||
|
position[ridx] = ~nid;
|
||||||
|
} else {
|
||||||
|
position[ridx] = nid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief this is helper function uses column based data structure,
|
||||||
|
* reset the positions to the lastest one
|
||||||
|
* \param nodes the set of nodes that contains the split to be used
|
||||||
|
* \param p_fmat feature matrix needed for tree construction
|
||||||
|
* \param tree the regression tree structure
|
||||||
|
*/
|
||||||
|
inline void ResetPositionCol(const std::vector<int> &nodes, IFMatrix *p_fmat, const RegTree &tree) {
|
||||||
|
// set the positions in the nondefault
|
||||||
|
this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
|
||||||
|
// set rest of instances to default position
|
||||||
|
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
||||||
|
// set default direct nodes to default
|
||||||
|
// for leaf nodes that are not fresh, mark then to ~nid,
|
||||||
|
// so that they are ignored in future statistics collection
|
||||||
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||||
|
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||||
|
const bst_uint ridx = rowset[i];
|
||||||
|
const int nid = this->DecodePosition(ridx);
|
||||||
|
if (tree[nid].is_leaf()) {
|
||||||
|
// mark finish when it is not a fresh leaf
|
||||||
|
if (tree[nid].cright() == -1) {
|
||||||
|
position[ridx] = ~nid;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// push to default branch
|
||||||
|
if (tree[nid].default_left()) {
|
||||||
|
this->SetEncodePosition(ridx, tree[nid].cleft());
|
||||||
|
} else {
|
||||||
|
this->SetEncodePosition(ridx, tree[nid].cright());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief this is helper function uses column based data structure,
|
||||||
|
* update all positions into nondefault branch, if any, ignore the default branch
|
||||||
|
* \param nodes the set of nodes that contains the split to be used
|
||||||
|
* \param p_fmat feature matrix needed for tree construction
|
||||||
|
* \param tree the regression tree structure
|
||||||
|
*/
|
||||||
|
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
|
||||||
|
IFMatrix *p_fmat, const RegTree &tree) {
|
||||||
|
// step 1, classify the non-default data into right places
|
||||||
|
std::vector<unsigned> fsplits;
|
||||||
|
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||||
|
const int nid = nodes[i];
|
||||||
|
if (!tree[nid].is_leaf()) {
|
||||||
|
fsplits.push_back(tree[nid].split_index());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::sort(fsplits.begin(), fsplits.end());
|
||||||
|
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||||
|
|
||||||
|
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fsplits);
|
||||||
|
while (iter->Next()) {
|
||||||
|
const ColBatch &batch = iter->Value();
|
||||||
|
for (size_t i = 0; i < batch.size; ++i) {
|
||||||
|
ColBatch::Inst col = batch[i];
|
||||||
|
const bst_uint fid = batch.col_index[i];
|
||||||
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||||
|
const bst_uint ridx = col[j].index;
|
||||||
|
const float fvalue = col[j].fvalue;
|
||||||
|
const int nid = this->DecodePosition(ridx);
|
||||||
|
// go back to parent, correct those who are not default
|
||||||
|
if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) {
|
||||||
|
if(fvalue < tree[nid].split_cond()) {
|
||||||
|
this->SetEncodePosition(ridx, tree[nid].cleft());
|
||||||
|
} else {
|
||||||
|
this->SetEncodePosition(ridx, tree[nid].cright());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
/*! \brief training parameter of tree grower */
|
/*! \brief training parameter of tree grower */
|
||||||
TrainParam param;
|
TrainParam param;
|
||||||
/*! \brief queue of nodes to be expanded */
|
/*! \brief queue of nodes to be expanded */
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user