xgboost/src/common/row_set.h
Tianqi Chen d581a3d0e7 [UPDATE] Update rabit and threadlocal (#2114)
* [UPDATE] Update rabit and threadlocal

* minor fix to make build system happy

* upgrade requirement to g++4.8

* upgrade dmlc-core

* update travis
2017-03-16 18:48:37 -07:00

105 lines
3.1 KiB
C++

/*!
* Copyright 2017 by Contributors
* \file row_set.h
* \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen
*/
#ifndef XGBOOST_COMMON_ROW_SET_H_
#define XGBOOST_COMMON_ROW_SET_H_
#include <xgboost/data.h>
#include <algorithm>
#include <vector>
namespace xgboost {
namespace common {
/*! \brief collection of rowset */
class RowSetCollection {
public:
/*! \brief subset of rows */
struct Elem {
const bst_uint* begin;
const bst_uint* end;
Elem(void)
: begin(nullptr), end(nullptr) {}
Elem(const bst_uint* begin,
const bst_uint* end)
: begin(begin), end(end) {}
inline size_t size() const {
return end - begin;
}
};
/* \brief specifies how to split a rowset into two */
struct Split {
std::vector<bst_uint> left;
std::vector<bst_uint> right;
};
/*! \brief return corresponding element set given the node_id */
inline const Elem& operator[](unsigned node_id) const {
const Elem& e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr)
<< "access element that is not in the set";
return e;
}
// clear up things
inline void Clear() {
row_indices_.clear();
elem_of_each_node_.clear();
}
// initialize node id 0->everything
inline void Init() {
CHECK_EQ(elem_of_each_node_.size(), 0U);
const bst_uint* begin = dmlc::BeginPtr(row_indices_);
const bst_uint* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
elem_of_each_node_.emplace_back(Elem(begin, end));
}
// split rowset into two
inline void AddSplit(unsigned node_id,
const std::vector<Split>& row_split_tloc,
unsigned left_node_id,
unsigned right_node_id) {
const Elem e = elem_of_each_node_[node_id];
const unsigned nthread = row_split_tloc.size();
CHECK(e.begin != nullptr);
bst_uint* all_begin = dmlc::BeginPtr(row_indices_);
bst_uint* begin = all_begin + (e.begin - all_begin);
bst_uint* it = begin;
// TODO(hcho3): parallelize this section
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
it += row_split_tloc[tid].left.size();
}
bst_uint* split_pt = it;
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
it += row_split_tloc[tid].right.size();
}
if (left_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr));
}
if (right_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr));
}
elem_of_each_node_[left_node_id] = Elem(begin, split_pt);
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end);
elem_of_each_node_[node_id] = Elem(nullptr, nullptr);
}
// stores the row indices in the set
std::vector<bst_uint> row_indices_;
private:
// vector: node_id -> elements
std::vector<Elem> elem_of_each_node_;
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_ROW_SET_H_