new change for mpi
This commit is contained in:
parent
a21df0770d
commit
0cf2dd39ea
@ -42,6 +42,7 @@ class GBLinear : public IGradBooster {
|
||||
model.InitModel();
|
||||
}
|
||||
virtual void DoBoost(IFMatrix *p_fmat,
|
||||
int64_t buffer_offset,
|
||||
const BoosterInfo &info,
|
||||
std::vector<bst_gpair> *in_gpair) {
|
||||
std::vector<bst_gpair> &gpair = *in_gpair;
|
||||
|
||||
@ -41,11 +41,14 @@ class IGradBooster {
|
||||
/*!
|
||||
* \brief peform update to the model(boosting)
|
||||
* \param p_fmat feature matrix that provide access to features
|
||||
* \param buffer_offset buffer index offset of these instances, if equals -1
|
||||
* this means we do not have buffer index allocated to the gbm
|
||||
* \param info meta information about training
|
||||
* \param in_gpair address of the gradient pair statistics of the data
|
||||
* the booster may change content of gpair
|
||||
*/
|
||||
virtual void DoBoost(IFMatrix *p_fmat,
|
||||
int64_t buffer_offset,
|
||||
const BoosterInfo &info,
|
||||
std::vector<bst_gpair> *in_gpair) = 0;
|
||||
/*!
|
||||
|
||||
@ -19,6 +19,8 @@ namespace gbm {
|
||||
*/
|
||||
class GBTree : public IGradBooster {
|
||||
public:
|
||||
GBTree(void) {
|
||||
}
|
||||
virtual ~GBTree(void) {
|
||||
this->Clear();
|
||||
}
|
||||
@ -83,11 +85,12 @@ class GBTree : public IGradBooster {
|
||||
utils::Assert(trees.size() == 0, "GBTree: model already initialized");
|
||||
}
|
||||
virtual void DoBoost(IFMatrix *p_fmat,
|
||||
int64_t buffer_offset,
|
||||
const BoosterInfo &info,
|
||||
std::vector<bst_gpair> *in_gpair) {
|
||||
const std::vector<bst_gpair> &gpair = *in_gpair;
|
||||
if (mparam.num_output_group == 1) {
|
||||
this->BoostNewTrees(gpair, p_fmat, info, 0);
|
||||
this->BoostNewTrees(gpair, p_fmat, buffer_offset, info, 0);
|
||||
} else {
|
||||
const int ngroup = mparam.num_output_group;
|
||||
utils::Check(gpair.size() % ngroup == 0,
|
||||
@ -99,7 +102,7 @@ class GBTree : public IGradBooster {
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
tmp[i] = gpair[i * ngroup + gid];
|
||||
}
|
||||
this->BoostNewTrees(tmp, p_fmat, info, gid);
|
||||
this->BoostNewTrees(tmp, p_fmat, buffer_offset, info, gid);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -190,6 +193,7 @@ class GBTree : public IGradBooster {
|
||||
// do group specific group
|
||||
inline void BoostNewTrees(const std::vector<bst_gpair> &gpair,
|
||||
IFMatrix *p_fmat,
|
||||
int64_t buffer_offset,
|
||||
const BoosterInfo &info,
|
||||
int bst_group) {
|
||||
this->InitUpdater();
|
||||
@ -206,6 +210,17 @@ class GBTree : public IGradBooster {
|
||||
for (size_t i = 0; i < updaters.size(); ++i) {
|
||||
updaters[i]->Update(gpair, p_fmat, info, new_trees);
|
||||
}
|
||||
// optimization, update buffer, if possible
|
||||
if (buffer_offset >= 0 &&
|
||||
new_trees.size() == 1 && updaters.size() > 0 &&
|
||||
updaters.back()->GetLeafPosition() != NULL) {
|
||||
utils::Check(info.num_row == p_fmat->buffered_rowset().size(),
|
||||
"distributed mode is not compatible with prob_buffer_row");
|
||||
this->UpdateBufferByPosition(p_fmat,
|
||||
buffer_offset, bst_group,
|
||||
*new_trees[0],
|
||||
updaters.back()->GetLeafPosition());
|
||||
}
|
||||
// push back to model
|
||||
for (size_t i = 0; i < new_trees.size(); ++i) {
|
||||
trees.push_back(new_trees[i]);
|
||||
@ -213,13 +228,36 @@ class GBTree : public IGradBooster {
|
||||
}
|
||||
mparam.num_trees += tparam.num_parallel_tree;
|
||||
}
|
||||
// update buffer by pre-cached position
|
||||
inline void UpdateBufferByPosition(IFMatrix *p_fmat,
|
||||
int64_t buffer_offset,
|
||||
int bst_group,
|
||||
const tree::RegTree &new_tree,
|
||||
const int* leaf_position) {
|
||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
const int64_t bid = mparam.BufferOffset(buffer_offset + ridx, bst_group);
|
||||
const int tid = leaf_position[ridx];
|
||||
utils::Assert(pred_counter[bid] == trees.size(), "cached buffer not up to date");
|
||||
utils::Assert(tid >= 0, "invalid leaf position");
|
||||
pred_buffer[bid] += new_tree[tid].leaf_value();
|
||||
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
|
||||
pred_buffer[bid + i + 1] += new_tree.leafvec(tid)[i];
|
||||
}
|
||||
pred_counter[bid] += 1;
|
||||
}
|
||||
}
|
||||
// make a prediction for a single instance
|
||||
inline void Pred(const RowBatch::Inst &inst,
|
||||
int64_t buffer_index,
|
||||
int bst_group,
|
||||
unsigned root_index,
|
||||
tree::RegTree::FVec *p_feats,
|
||||
float *out_pred, size_t stride, unsigned ntree_limit) {
|
||||
float *out_pred, size_t stride,
|
||||
unsigned ntree_limit) {
|
||||
size_t itop = 0;
|
||||
float psum = 0.0f;
|
||||
// sum of leaf vector
|
||||
|
||||
@ -173,7 +173,7 @@ class BoostLearner {
|
||||
inline void UpdateOneIter(int iter, const DMatrix &train) {
|
||||
this->PredictRaw(train, &preds_);
|
||||
obj_->GetGradient(preds_, train.info, iter, &gpair_);
|
||||
gbm_->DoBoost(train.fmat(), train.info.info, &gpair_);
|
||||
gbm_->DoBoost(train.fmat(), this->FindBufferOffset(train), train.info.info, &gpair_);
|
||||
}
|
||||
/*!
|
||||
* \brief evaluate the model for specific iteration
|
||||
@ -335,7 +335,7 @@ class BoostLearner {
|
||||
// gradient pairs
|
||||
std::vector<bst_gpair> gpair_;
|
||||
|
||||
private:
|
||||
protected:
|
||||
// cache entry object that helps handle feature caching
|
||||
struct CacheEntry {
|
||||
const DMatrix *mat_;
|
||||
|
||||
@ -13,8 +13,8 @@ IUpdater* CreateUpdater(const char *name) {
|
||||
using namespace std;
|
||||
if (!strcmp(name, "prune")) return new TreePruner();
|
||||
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
|
||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
|
||||
if (!strcmp(name, "distcol")) return new DistColMaker<GradStats>();
|
||||
if (!strcmp(name, "grow_colmaker5")) return new ColMaker< CVGradStats<5> >();
|
||||
if (!strcmp(name, "grow_colmaker3")) return new ColMaker< CVGradStats<3> >();
|
||||
utils::Error("unknown updater:%s", name);
|
||||
|
||||
@ -37,6 +37,16 @@ class IUpdater {
|
||||
IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
const std::vector<RegTree*> &trees) = 0;
|
||||
|
||||
/*!
|
||||
* \brief this is simply a function for optimizing performance
|
||||
* this function asks the updater to return the leaf position of each instance in the p_fmat,
|
||||
* if it is cached in the updater, if it is not available, return NULL
|
||||
* \return array of leaf position of each instance in the last updated tree
|
||||
*/
|
||||
virtual const int* GetLeafPosition(void) const {
|
||||
return NULL;
|
||||
}
|
||||
// destructor
|
||||
virtual ~IUpdater(void) {}
|
||||
};
|
||||
|
||||
@ -38,7 +38,9 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
// update position after the tree is pruned
|
||||
builder.UpdatePosition(p_fmat, *trees[0]);
|
||||
}
|
||||
|
||||
virtual const int* GetLeafPosition(void) const {
|
||||
return builder.GetLeafPosition();
|
||||
}
|
||||
private:
|
||||
inline void SyncTrees(RegTree *tree) {
|
||||
std::string s_model;
|
||||
@ -71,6 +73,9 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
this->position[ridx] = nid;
|
||||
}
|
||||
}
|
||||
virtual const int* GetLeafPosition(void) const {
|
||||
return BeginPtr(this->position);
|
||||
}
|
||||
protected:
|
||||
virtual void SetNonDefaultPosition(const std::vector<int> &qexpand,
|
||||
IFMatrix *p_fmat, const RegTree &tree) {
|
||||
|
||||
@ -44,7 +44,7 @@ class Booster: public learner::BoostLearner {
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
gpair_[j] = bst_gpair(grad[j], hess[j]);
|
||||
}
|
||||
gbm_->DoBoost(train.fmat(), train.info.info, &gpair_);
|
||||
gbm_->DoBoost(train.fmat(), this->FindBufferOffset(train), train.info.info, &gpair_);
|
||||
}
|
||||
inline void CheckInitModel(void) {
|
||||
if (!init_model) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user