updated base

This commit is contained in:
tqchen 2014-11-17 10:49:53 -08:00
parent 8874234e5e
commit 5e8e9a9b74

View File

@ -27,8 +27,8 @@ class BaseMaker: public IUpdater {
protected: protected:
// ------static helper functions ------ // ------static helper functions ------
// helper function to get to next level of the tree // helper function to get to next level of the tree
// must work on non-leaf node /*! \brief this is helper function for row based data*/
inline static int NextLevel(const SparseBatch::Inst &inst, const RegTree &tree, int nid) { inline static int NextLevel(const RowBatch::Inst &inst, const RegTree &tree, int nid) {
const RegTree::Node &n = tree[nid]; const RegTree::Node &n = tree[nid];
bst_uint findex = n.split_index(); bst_uint findex = n.split_index();
for (unsigned i = 0; i < inst.length; ++i) { for (unsigned i = 0; i < inst.length; ++i) {
@ -52,19 +52,6 @@ class BaseMaker: public IUpdater {
return nthread; return nthread;
} }
// ------class member helpers--------- // ------class member helpers---------
// return decoded position
inline int DecodePosition(bst_uint ridx) const{
const int pid = position[ridx];
return pid < 0 ? ~pid : pid;
}
// encode the encoded position value for ridx
inline void SetEncodePosition(bst_uint ridx, int nid) {
if (position[ridx] < 0) {
position[ridx] = ~nid;
} else {
position[ridx] = nid;
}
}
/*! \brief initialize temp data structure */ /*! \brief initialize temp data structure */
inline void InitData(const std::vector<bst_gpair> &gpair, inline void InitData(const std::vector<bst_gpair> &gpair,
const IFMatrix &fmat, const IFMatrix &fmat,
@ -117,6 +104,99 @@ class BaseMaker: public IUpdater {
qexpand = newnodes; qexpand = newnodes;
this->UpdateNode2WorkIndex(tree); this->UpdateNode2WorkIndex(tree);
} }
// return decoded position
inline int DecodePosition(bst_uint ridx) const{
const int pid = position[ridx];
return pid < 0 ? ~pid : pid;
}
// encode the encoded position value for ridx
inline void SetEncodePosition(bst_uint ridx, int nid) {
if (position[ridx] < 0) {
position[ridx] = ~nid;
} else {
position[ridx] = nid;
}
}
/*!
* \brief this is helper function uses column based data structure,
* reset the positions to the lastest one
* \param nodes the set of nodes that contains the split to be used
* \param p_fmat feature matrix needed for tree construction
* \param tree the regression tree structure
*/
inline void ResetPositionCol(const std::vector<int> &nodes, IFMatrix *p_fmat, const RegTree &tree) {
// set the positions in the nondefault
this->SetNonDefaultPositionCol(nodes, p_fmat, tree);
// set rest of instances to default position
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
// set default direct nodes to default
// for leaf nodes that are not fresh, mark then to ~nid,
// so that they are ignored in future statistics collection
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i];
const int nid = this->DecodePosition(ridx);
if (tree[nid].is_leaf()) {
// mark finish when it is not a fresh leaf
if (tree[nid].cright() == -1) {
position[ridx] = ~nid;
}
} else {
// push to default branch
if (tree[nid].default_left()) {
this->SetEncodePosition(ridx, tree[nid].cleft());
} else {
this->SetEncodePosition(ridx, tree[nid].cright());
}
}
}
}
/*!
* \brief this is helper function uses column based data structure,
* update all positions into nondefault branch, if any, ignore the default branch
* \param nodes the set of nodes that contains the split to be used
* \param p_fmat feature matrix needed for tree construction
* \param tree the regression tree structure
*/
virtual void SetNonDefaultPositionCol(const std::vector<int> &nodes,
IFMatrix *p_fmat, const RegTree &tree) {
// step 1, classify the non-default data into right places
std::vector<unsigned> fsplits;
for (size_t i = 0; i < nodes.size(); ++i) {
const int nid = nodes[i];
if (!tree[nid].is_leaf()) {
fsplits.push_back(tree[nid].split_index());
}
}
std::sort(fsplits.begin(), fsplits.end());
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fsplits);
while (iter->Next()) {
const ColBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
ColBatch::Inst col = batch[i];
const bst_uint fid = batch.col_index[i];
const bst_omp_uint ndata = static_cast<bst_omp_uint>(col.length);
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue;
const int nid = this->DecodePosition(ridx);
// go back to parent, correct those who are not default
if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) {
if(fvalue < tree[nid].split_cond()) {
this->SetEncodePosition(ridx, tree[nid].cleft());
} else {
this->SetEncodePosition(ridx, tree[nid].cright());
}
}
}
}
}
}
/*! \brief training parameter of tree grower */ /*! \brief training parameter of tree grower */
TrainParam param; TrainParam param;
/*! \brief queue of nodes to be expanded */ /*! \brief queue of nodes to be expanded */