#ifndef XGBOOST_TREE_UPDATER_DISTCOL_INL_HPP_ #define XGBOOST_TREE_UPDATER_DISTCOL_INL_HPP_ /*! * \file updater_distcol-inl.hpp * \brief beta distributed version that takes a sub-column * and construct a tree * \author Tianqi Chen */ #include "../utils/bitmap.h" #include "../utils/io.h" #include "../sync/sync.h" #include "./updater_colmaker-inl.hpp" #include "./updater_prune-inl.hpp" namespace xgboost { namespace tree { template class DistColMaker : public ColMaker { public: DistColMaker(void) : builder(param) {} virtual ~DistColMaker(void) {} // set training parameter virtual void SetParam(const char *name, const char *val) { param.SetParam(name, val); pruner.SetParam(name, val); } virtual void Update(const std::vector &gpair, IFMatrix *p_fmat, const BoosterInfo &info, const std::vector &trees) { TStats::CheckInfo(info); utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time"); // build the tree builder.Update(gpair, p_fmat, info, trees[0]); // prune the tree pruner.Update(gpair, p_fmat, info, trees); this->SyncTrees(trees[0]); // update position after the tree is pruned builder.UpdatePosition(p_fmat, *trees[0]); } private: inline void SyncTrees(RegTree *tree) { std::string s_model; utils::MemoryBufferStream fs(&s_model); int rank = sync::GetRank(); if (rank == 0) { tree->SaveModel(fs); sync::Bcast(&s_model, 0); } else { sync::Bcast(&s_model, 0); tree->LoadModel(fs); } } struct Builder : public ColMaker::Builder { public: Builder(const TrainParam ¶m) : ColMaker::Builder(param) { } inline void UpdatePosition(IFMatrix *p_fmat, const RegTree &tree) { const std::vector &rowset = p_fmat->buffered_rowset(); const bst_omp_uint ndata = static_cast(rowset.size()); #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < ndata; ++i) { const bst_uint ridx = rowset[i]; int nid = this->position[ridx]; if (nid < 0) { } } } protected: virtual void SetNonDefaultPosition(const std::vector &qexpand, IFMatrix *p_fmat, const RegTree &tree) { // step 2, classify the non-default data into right places std::vector fsplits; for (size_t i = 0; i < qexpand.size(); ++i) { const int nid = qexpand[i]; if (!tree[nid].is_leaf()) { fsplits.push_back(tree[nid].split_index()); } } // get the candidate split index std::sort(fsplits.begin(), fsplits.end()); fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); while (fsplits.size() != 0 && fsplits.back() >= p_fmat->NumCol()) { fsplits.pop_back(); } // setup BitMap bitmap.Resize(this->position.size()); bitmap.Clear(); utils::IIterator *iter = p_fmat->ColIterator(fsplits); while (iter->Next()) { const ColBatch &batch = iter->Value(); for (size_t i = 0; i < batch.size; ++i) { ColBatch::Inst col = batch[i]; const bst_uint fid = batch.col_index[i]; const bst_omp_uint ndata = static_cast(col.length); #pragma omp parallel for schedule(static) for (bst_omp_uint j = 0; j < ndata; ++j) { const bst_uint ridx = col[j].index; const float fvalue = col[j].fvalue; const int nid = this->DecodePosition(ridx); if (!tree[nid].is_leaf() && tree[nid].split_index() == fid) { if (fvalue < tree[nid].split_cond()) { if (!tree[nid].default_left()) bitmap.SetTrue(ridx); } else { if (tree[nid].default_left()) bitmap.SetTrue(ridx); } } } } } // communicate bitmap sync::AllReduce(BeginPtr(bitmap.data), bitmap.data.size(), sync::kBitwiseOR); const std::vector &rowset = p_fmat->buffered_rowset(); // get the new position const bst_omp_uint ndata = static_cast(rowset.size()); #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < ndata; ++i) { const bst_uint ridx = rowset[i]; const int nid = this->DecodePosition(ridx); if (bitmap.Get(ridx)) { utils::Assert(!tree[nid].is_leaf(), "inconsistent reduce information"); if (tree[nid].default_left()) { this->SetEncodePosition(ridx, tree[nid].cright()); } else { this->SetEncodePosition(ridx, tree[nid].cright()); } } } } // synchronize the best solution of each node virtual void SyncBestSolution(const std::vector &qexpand) { std::vector vec; for (size_t i = 0; i < qexpand.size(); ++i) { const int nid = qexpand[i]; for (int tid = 0; tid < this->nthread; ++tid) { this->snode[nid].best.Update(this->stemp[tid][nid].best); } vec.push_back(this->snode[nid].best); } // communicate best solution reducer.AllReduce(BeginPtr(vec), vec.size()); // assign solution back for (size_t i = 0; i < qexpand.size(); ++i) { const int nid = qexpand[i]; this->snode[nid].best = vec[i]; } } private: utils::BitMap bitmap; sync::Reducer reducer; }; // we directly introduce pruner here TreePruner pruner; // training parameter TrainParam param; // pointer to the builder Builder builder; }; } // namespace tree } // namespace xgboost #endif