From ca9646874572af2ff245ef817b220d802a8017bf Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 2 Nov 2014 21:52:59 -0800 Subject: [PATCH] everything is ready, except for propose --- src/tree/model.h | 2 +- src/tree/updater_histmaker-inl.hpp | 127 +++++++++++++++++++++++------ 2 files changed, 105 insertions(+), 24 deletions(-) diff --git a/src/tree/model.h b/src/tree/model.h index 84010bcc0..a330e2960 100644 --- a/src/tree/model.h +++ b/src/tree/model.h @@ -68,7 +68,7 @@ class TreeModel { } }; /*! \brief tree node */ - class Node{ + class Node { public: /*! \brief index of left child */ inline int cleft(void) const { diff --git a/src/tree/updater_histmaker-inl.hpp b/src/tree/updater_histmaker-inl.hpp index 48f341469..afbafdeac 100644 --- a/src/tree/updater_histmaker-inl.hpp +++ b/src/tree/updater_histmaker-inl.hpp @@ -29,7 +29,7 @@ class HistMaker: public IUpdater { param.learning_rate = lr / trees.size(); // build tree for (size_t i = 0; i < trees.size(); ++i) { - // TODO + this->Update(gpair, p_fmat, info, trees[i]); } param.learning_rate = lr; } @@ -80,25 +80,16 @@ class HistMaker: public IUpdater { // per thread histset std::vector hset; // initialize the hist set - inline void Init(const TrainParam &param) { - int nthread; - #pragma omp parallel - { - nthread = omp_get_num_threads(); - } + inline void Init(const TrainParam &param, int nthread) { hset.resize(nthread); // cleanup statistics - #pragma omp parallel - { - int tid = omp_get_thread_num(); + for (int tid = 0; tid < nthread; ++tid) { for (size_t i = 0; i < hset[tid].data.size(); ++i) { hset[tid].data[i].Clear(); } - } - for (int i = 0; i < nthread; ++i) { - hset[i].rptr = BeginPtr(rptr); - hset[i].cut = BeginPtr(cut); - hset[i].data.resize(cut.size(), TStats(param)); + hset[tid].rptr = BeginPtr(rptr); + hset[tid].cut = BeginPtr(cut); + hset[tid].data.resize(cut.size(), TStats(param)); } } // aggregate all statistics to hset[0] @@ -119,7 +110,7 @@ class HistMaker: public IUpdater { inline size_t Size(void) 
const { return rptr.size() - 1; } - }; + }; // training parameter TrainParam param; // workspace of thread @@ -132,30 +123,116 @@ class HistMaker: public IUpdater { std::vector node2workindex; // reducer for histogram sync::Reducer histred; + + // helper function to get to next level of the tree + // must work on non-leaf node + inline static int NextLevel(const SparseBatch::Inst &inst, const RegTree &tree, int nid) { + const RegTree::Node &n = tree[nid]; + bst_uint findex = n.split_index(); + for (unsigned i = 0; i < inst.length; ++i) { + if (findex == inst[i].index) { + if (inst[i].fvalue < n.split_cond()) { + return n.cleft(); + } else { + return n.cright(); + } + } + } + return n.cdefault(); + } + private: virtual void Update(const std::vector &gpair, IFMatrix *p_fmat, const BoosterInfo &info, RegTree *p_tree) { - //this->InitData(gpair, *p_fmat, info.root_index, *p_tree); - //this->InitNewNode(qexpand_, gpair, *p_fmat, info, *p_tree); - for (int depth = 0; depth < param.max_depth; ++depth) { + this->InitData(gpair, *p_fmat, info.root_index, *p_tree); + this->UpdateNode2WorkIndex(*p_tree); + for (int depth = 0; depth < param.max_depth; ++depth) { this->FindSplit(depth, gpair, p_fmat, info, p_tree); - //this->ResetPosition(qexpand_, p_fmat, *p_tree); - //this->UpdateQueueExpand(*p_tree, &qexpand_); - //this->InitNewNode(qexpand_, gpair, *p_fmat, info, *p_tree); + this->UpdateQueueExpand(*p_tree); + this->UpdateNode2WorkIndex(*p_tree); // if nothing left to be expand, break if (qexpand.size() == 0) break; } } + // initialize temp data structure + inline void InitData(const std::vector &gpair, + const IFMatrix &fmat, + const std::vector &root_index, const RegTree &tree) { + utils::Assert(tree.param.num_nodes == tree.param.num_roots, "HistMaker: can only grow new tree"); + {// setup position + position.resize(gpair.size()); + if (root_index.size() == 0) { + std::fill(position.begin(), position.end(), 0); + } else { + for (size_t i = 0; i < position.size(); ++i) { + 
position[i] = root_index[i]; + utils::Assert(root_index[i] < (unsigned)tree.param.num_roots, + "root index exceed setting"); + } + } + // mark delete for the deleted datas + for (size_t i = 0; i < position.size(); ++i) { + if (gpair[i].hess < 0.0f) position[i] = ~position[i]; + } + // mark subsample + if (param.subsample < 1.0f) { + for (size_t i = 0; i < position.size(); ++i) { + if (gpair[i].hess < 0.0f) continue; + if (random::SampleBinary(param.subsample) == 0) position[i] = ~position[i]; + } + } + } + {// expand query + qexpand.reserve(256); qexpand.clear(); + for (int i = 0; i < tree.param.num_roots; ++i) { + qexpand.push_back(i); + } + } + } + /*! \brief update queue expand add in new leaves */ + inline void UpdateQueueExpand(const RegTree &tree) { + std::vector newnodes; + for (size_t i = 0; i < qexpand.size(); ++i) { + const int nid = qexpand[i]; + if (!tree[nid].is_leaf()) { + newnodes.push_back(tree[nid].cleft()); + newnodes.push_back(tree[nid].cright()); + } + } + // use new nodes for qexpand + qexpand = newnodes; + } + inline void UpdateNode2WorkIndex(const RegTree &tree) { + // update the node2workindex + std::fill(node2workindex.begin(), node2workindex.end(), -1); + node2workindex.resize(tree.param.num_nodes); + for (size_t i = 0; i < qexpand.size(); ++i) { + node2workindex[qexpand[i]] = static_cast(i); + } + } + // this function does two jobs + // (1) reset the position in array position, to be the latest leaf id + // (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly + virtual void ResetPosAndPropose(IFMatrix *p_fmat, + const BoosterInfo &info, + const RegTree &tree) { + + } // create histogram for a setup histset inline void CreateHist(const std::vector &gpair, IFMatrix *p_fmat, const BoosterInfo &info, const RegTree &tree) { bst_uint num_feature = tree.param.num_feature; + int nthread; + #pragma omp parallel + { + nthread = omp_get_num_threads(); + } // intialize work space - wspace.Init(param); + wspace.Init(param, 
nthread); // start accumulating statistics utils::IIterator *iter = p_fmat->RowIterator(); iter->BeforeFirst(); @@ -225,6 +302,8 @@ class HistMaker: public IUpdater { const BoosterInfo &info, RegTree *p_tree) { const bst_uint num_feature = p_tree->param.num_feature; + // reset and propose candidate split + this->ResetPosAndPropose(p_fmat, info, *p_tree); // create histogram this->CreateHist(gpair, p_fmat, info, *p_tree); // get the best split condition for each node @@ -265,6 +344,8 @@ class HistMaker: public IUpdater { } } }; + + } // namespace tree } // namespace xgboost #endif // XGBOOST_TREE_UPDATER_HISTMAKER_INL_HPP_