@@ -17,7 +17,6 @@
|
||||
#include "param.h"
|
||||
#include "constraints.h"
|
||||
#include "../common/random.h"
|
||||
#include "../common/bitmap.h"
|
||||
#include "split_evaluator.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -618,171 +617,10 @@ class ColMaker: public TreeUpdater {
|
||||
};
|
||||
};
|
||||
|
||||
// distributed column maker
|
||||
class DistColMaker : public ColMaker {
|
||||
public:
|
||||
void Configure(const Args& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
pruner_.reset(TreeUpdater::Create("prune", tparam_));
|
||||
pruner_->Configure(args);
|
||||
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
|
||||
spliteval_->Init(¶m_);
|
||||
}
|
||||
|
||||
char const* Name() const override {
|
||||
return "distcol";
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix* dmat,
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
CHECK_EQ(trees.size(), 1U) << "DistColMaker: only support one tree at a time";
|
||||
this->LazyGetColumnDensity(dmat);
|
||||
Builder builder(
|
||||
param_,
|
||||
colmaker_param_,
|
||||
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
|
||||
interaction_constraints_, column_densities_);
|
||||
// build the tree
|
||||
builder.Update(gpair->ConstHostVector(), dmat, trees[0]);
|
||||
//// prune the tree, note that pruner will sync the tree
|
||||
pruner_->Update(gpair, dmat, trees);
|
||||
// update position after the tree is pruned
|
||||
builder.UpdatePosition(dmat, *trees[0]);
|
||||
}
|
||||
|
||||
private:
|
||||
class Builder : public ColMaker::Builder {
|
||||
public:
|
||||
explicit Builder(const TrainParam ¶m,
|
||||
ColMakerTrainParam const &colmaker_train_param,
|
||||
std::unique_ptr<SplitEvaluator> spliteval,
|
||||
FeatureInteractionConstraintHost _interaction_constraints,
|
||||
const std::vector<float> &column_densities)
|
||||
: ColMaker::Builder(param, colmaker_train_param,
|
||||
std::move(spliteval),
|
||||
std::move(_interaction_constraints),
|
||||
column_densities) {}
|
||||
inline void UpdatePosition(DMatrix* p_fmat, const RegTree &tree) {
|
||||
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) {
|
||||
int nid = this->DecodePosition(ridx);
|
||||
while (tree[nid].IsDeleted()) {
|
||||
nid = tree[nid].Parent();
|
||||
CHECK_GE(nid, 0);
|
||||
}
|
||||
this->position_[ridx] = nid;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetNonDefaultPosition(const std::vector<int> &qexpand, DMatrix *p_fmat,
|
||||
const RegTree &tree) override {
|
||||
// step 2, classify the non-default data into right places
|
||||
std::vector<unsigned> fsplits;
|
||||
for (int nid : qexpand) {
|
||||
if (!tree[nid].IsLeaf()) {
|
||||
fsplits.push_back(tree[nid].SplitIndex());
|
||||
}
|
||||
}
|
||||
// get the candidate split index
|
||||
std::sort(fsplits.begin(), fsplits.end());
|
||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||
while (fsplits.size() != 0 && fsplits.back() >= p_fmat->Info().num_col_) {
|
||||
fsplits.pop_back();
|
||||
}
|
||||
// bitmap is only word concurrent, set to bool first
|
||||
{
|
||||
auto ndata = static_cast<bst_omp_uint>(this->position_.size());
|
||||
boolmap_.resize(ndata);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
boolmap_[j] = 0;
|
||||
}
|
||||
}
|
||||
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
|
||||
for (auto fid : fsplits) {
|
||||
auto col = batch[fid];
|
||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
const bst_uint ridx = col[j].index;
|
||||
const bst_float fvalue = col[j].fvalue;
|
||||
const int nid = this->DecodePosition(ridx);
|
||||
if (!tree[nid].IsLeaf() && tree[nid].SplitIndex() == fid) {
|
||||
if (fvalue < tree[nid].SplitCond()) {
|
||||
if (!tree[nid].DefaultLeft()) boolmap_[ridx] = 1;
|
||||
} else {
|
||||
if (tree[nid].DefaultLeft()) boolmap_[ridx] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bitmap_.InitFromBool(boolmap_);
|
||||
// communicate bitmap
|
||||
rabit::Allreduce<rabit::op::BitOR>(dmlc::BeginPtr(bitmap_.data), bitmap_.data.size());
|
||||
// get the new position
|
||||
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) {
|
||||
const int nid = this->DecodePosition(ridx);
|
||||
if (bitmap_.Get(ridx)) {
|
||||
CHECK(!tree[nid].IsLeaf()) << "inconsistent reduce information";
|
||||
if (tree[nid].DefaultLeft()) {
|
||||
this->SetEncodePosition(ridx, tree[nid].RightChild());
|
||||
} else {
|
||||
this->SetEncodePosition(ridx, tree[nid].LeftChild());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// synchronize the best solution of each node
|
||||
void SyncBestSolution(const std::vector<int> &qexpand) override {
|
||||
std::vector<SplitEntry> vec;
|
||||
for (int nid : qexpand) {
|
||||
for (int tid = 0; tid < this->nthread_; ++tid) {
|
||||
this->snode_[nid].best.Update(this->stemp_[tid][nid].best);
|
||||
}
|
||||
vec.push_back(this->snode_[nid].best);
|
||||
}
|
||||
// TODO(tqchen) lazy version
|
||||
// communicate best solution
|
||||
reducer_.Allreduce(dmlc::BeginPtr(vec), vec.size());
|
||||
// assign solution back
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
this->snode_[nid].best = vec[i];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
common::BitMap bitmap_;
|
||||
std::vector<int> boolmap_;
|
||||
rabit::Reducer<SplitEntry, SplitEntry::Reduce> reducer_;
|
||||
};
|
||||
// we directly introduce pruner here
|
||||
std::unique_ptr<TreeUpdater> pruner_;
|
||||
// training parameter
|
||||
TrainParam param_;
|
||||
// Cloned for each builder instantiation
|
||||
std::unique_ptr<SplitEvaluator> spliteval_;
|
||||
|
||||
FeatureInteractionConstraintHost interaction_constraints_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")
|
||||
.describe("Grow tree with parallelization over columns.")
|
||||
.set_body([]() {
|
||||
return new ColMaker();
|
||||
});
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(DistColMaker, "distcol")
|
||||
.describe("Distributed column split version of tree maker.")
|
||||
.set_body([]() {
|
||||
return new DistColMaker();
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user