middle version
This commit is contained in:
@@ -110,6 +110,10 @@ class TreeModel {
|
||||
inline bool is_left_child(void) const {
|
||||
return (parent_ & (1U << 31)) != 0;
|
||||
}
|
||||
/*! \brief whether this node is deleted */
|
||||
inline bool is_deleted(void) const {
|
||||
return sindex_ == std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
/*! \brief whether current node is root */
|
||||
inline bool is_root(void) const {
|
||||
return parent_ == -1;
|
||||
@@ -144,7 +148,11 @@ class TreeModel {
|
||||
this->cleft_ = -1;
|
||||
this->cright_ = right;
|
||||
}
|
||||
|
||||
/*! \brief mark that this node is deleted */
|
||||
inline void mark_delete(void) {
|
||||
this->sindex_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
private:
|
||||
friend class TreeModel<TSplitCond, TNodeStat>;
|
||||
/*!
|
||||
@@ -197,11 +205,11 @@ class TreeModel {
|
||||
leaf_vector.resize(param.num_nodes * param.size_leaf_vector);
|
||||
return nd;
|
||||
}
|
||||
// delete a tree node
|
||||
// delete a tree node, keep the parent field to allow trace back
|
||||
inline void DeleteNode(int nid) {
|
||||
utils::Assert(nid >= param.num_roots, "can not delete root");
|
||||
deleted_nodes.push_back(nid);
|
||||
nodes[nid].set_parent(-1);
|
||||
nodes[nid].mark_delete();
|
||||
++param.num_deleted;
|
||||
}
|
||||
|
||||
|
||||
@@ -345,6 +345,10 @@ struct SplitEntry{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*! \brief same as update, used by AllReduce*/
|
||||
inline void Reduce(const SplitEntry &e) {
|
||||
this->Update(e);
|
||||
}
|
||||
/*!\return feature index to split on */
|
||||
inline unsigned split_index(void) const {
|
||||
return sindex & ((1U << 31) - 1U);
|
||||
|
||||
@@ -486,13 +486,17 @@ class ColMaker: public IUpdater {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
const int nid = position[ridx];
|
||||
if (nid >= 0) {
|
||||
if (tree[nid].is_leaf()) {
|
||||
position[ridx] = - nid - 1;
|
||||
int nid = position[ridx];
|
||||
if (nid < 0) nid = ~nid;
|
||||
if (tree[nid].is_leaf()) {
|
||||
position[ridx] = ~nid;
|
||||
} else {
|
||||
// push to default branch, correct latter
|
||||
int pid = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
|
||||
if (position[ridx] < 0) {
|
||||
position[ridx] = ~pid;
|
||||
} else {
|
||||
// push to default branch, correct latter
|
||||
position[ridx] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
|
||||
position[ridx] = pid;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -535,7 +539,8 @@ class ColMaker: public IUpdater {
|
||||
const bst_uint ridx = col[j].index;
|
||||
const float fvalue = col[j].fvalue;
|
||||
int nid = position[ridx];
|
||||
if (nid < 0) continue;
|
||||
if (nid < 0) nid = ~nid;
|
||||
|
||||
// go back to parent, correct those who are not default
|
||||
nid = tree[nid].parent();
|
||||
if (tree[nid].split_index() == fid) {
|
||||
|
||||
@@ -7,7 +7,10 @@
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include "../utils/bitmap.h"
|
||||
#include "../utils/io.h"
|
||||
#include "../sync/sync.h"
|
||||
#include "./updater_colmaker-inl.hpp"
|
||||
#include "./updater_prune-inl.hpp"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -19,6 +22,7 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
// set training parameter
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
param.SetParam(name, val);
|
||||
pruner.SetParam(name, val);
|
||||
}
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
@@ -26,15 +30,46 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
const std::vector<RegTree*> &trees) {
|
||||
TStats::CheckInfo(info);
|
||||
utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time");
|
||||
// build the tree
|
||||
builder.Update(gpair, p_fmat, info, trees[0]);
|
||||
// prune the tree
|
||||
pruner.Update(gpair, p_fmat, info, trees);
|
||||
this->SyncTrees(trees[0]);
|
||||
// update position after the tree is pruned
|
||||
builder.UpdatePosition(p_fmat, *trees[0]);
|
||||
}
|
||||
|
||||
private:
|
||||
inline void SyncTrees(RegTree *tree) {
|
||||
std::string s_model;
|
||||
utils::MemoryBufferStream fs(&s_model);
|
||||
int rank = sync::GetRank();
|
||||
if (rank == 0) {
|
||||
tree->SaveModel(fs);
|
||||
sync::Bcast(&s_model, 0);
|
||||
} else {
|
||||
sync::Bcast(&s_model, 0);
|
||||
tree->LoadModel(fs);
|
||||
}
|
||||
}
|
||||
struct Builder : public ColMaker<TStats>::Builder {
|
||||
public:
|
||||
Builder(const TrainParam ¶m)
|
||||
: ColMaker<TStats>::Builder(param) {
|
||||
}
|
||||
protected:
|
||||
inline void UpdatePosition(IFMatrix *p_fmat, const RegTree &tree) {
|
||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
int nid = this->position[ridx];
|
||||
if (nid < 0) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
protected:
|
||||
virtual void SetNonDefaultPosition(const std::vector<int> &qexpand,
|
||||
IFMatrix *p_fmat, const RegTree &tree) {
|
||||
// step 2, classify the non-default data into right places
|
||||
@@ -80,8 +115,8 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
}
|
||||
}
|
||||
// communicate bitmap
|
||||
//sync::AllReduce();
|
||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
||||
sync::AllReduce(BeginPtr(bitmap.data), bitmap.data.size(), sync::kBitwiseOR);
|
||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
||||
// get the new position
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
@@ -100,19 +135,29 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
}
|
||||
// synchronize the best solution of each node
|
||||
virtual void SyncBestSolution(const std::vector<int> &qexpand) {
|
||||
std::vector<SplitEntry> vec;
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
for (int tid = 0; tid < this->nthread; ++tid) {
|
||||
this->snode[nid].best.Update(this->stemp[tid][nid].best);
|
||||
}
|
||||
vec.push_back(this->snode[nid].best);
|
||||
}
|
||||
// communicate best solution
|
||||
// sync::AllReduce
|
||||
reducer.AllReduce(BeginPtr(vec), vec.size());
|
||||
// assign solution back
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
this->snode[nid].best = vec[i];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
utils::BitMap bitmap;
|
||||
sync::Reducer<SplitEntry> reducer;
|
||||
};
|
||||
// we directly introduce pruner here
|
||||
TreePruner pruner;
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
// pointer to the builder
|
||||
|
||||
Reference in New Issue
Block a user