Small cleanup to hist tree method. (#7735)

* Remove special optimization using number of bins.
* Remove 1-based index for column sampling.
* Remove data layout.
* Unify update prediction cache.
This commit is contained in:
Jiaming Yuan
2022-03-20 03:44:55 +08:00
committed by GitHub
parent 718472dbe2
commit 996cc705af
9 changed files with 140 additions and 205 deletions

View File

@@ -156,9 +156,8 @@ class ColumnSampler {
* \param colsample_bytree
* \param skip_index_0 (Optional) True to skip index 0.
*/
void Init(int64_t num_col, std::vector<float> feature_weights,
float colsample_bynode, float colsample_bylevel,
float colsample_bytree, bool skip_index_0 = false) {
void Init(int64_t num_col, std::vector<float> feature_weights, float colsample_bynode,
float colsample_bylevel, float colsample_bytree) {
feature_weights_ = std::move(feature_weights);
colsample_bylevel_ = colsample_bylevel;
colsample_bytree_ = colsample_bytree;
@@ -169,10 +168,8 @@ class ColumnSampler {
}
Reset();
int begin_idx = skip_index_0 ? 1 : 0;
feature_set_tree_->Resize(num_col - begin_idx);
std::iota(feature_set_tree_->HostVector().begin(),
feature_set_tree_->HostVector().end(), begin_idx);
feature_set_tree_->Resize(num_col);
std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
}

View File

@@ -55,8 +55,6 @@ class RowSetCollection {
/*! \brief return corresponding element set given the node_id */
inline const Elem& operator[](unsigned node_id) const {
const Elem& e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr)
<< "access element that is not in the set";
return e;
}
@@ -75,14 +73,10 @@ class RowSetCollection {
CHECK_EQ(elem_of_each_node_.size(), 0U);
if (row_indices_.empty()) { // edge case: empty instance set
// assign arbitrary address here, to bypass nullptr check
// (nullptr usually indicates a nonexistent rowset, but we want to
// indicate a valid rowset that happens to have zero length and occupies
// the whole instance set)
// this is okay, as BuildHist will compute (end-begin) as the set size
const size_t* begin = reinterpret_cast<size_t*>(20);
const size_t* end = begin;
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
constexpr size_t* kBegin = nullptr;
constexpr size_t* kEnd = nullptr;
static_assert(kEnd - kBegin == 0, "");
elem_of_each_node_.emplace_back(Elem(kBegin, kEnd, 0));
return;
}
@@ -93,15 +87,19 @@ class RowSetCollection {
std::vector<size_t>* Data() { return &row_indices_; }
// split rowset into two
inline void AddSplit(unsigned node_id,
unsigned left_node_id,
unsigned right_node_id,
size_t n_left,
size_t n_right) {
inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
size_t n_left, size_t n_right) {
const Elem e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr);
size_t* all_begin = dmlc::BeginPtr(row_indices_);
size_t* begin = all_begin + (e.begin - all_begin);
size_t* all_begin{nullptr};
size_t* begin{nullptr};
if (e.begin == nullptr) {
CHECK_EQ(n_left, 0);
CHECK_EQ(n_right, 0);
} else {
all_begin = dmlc::BeginPtr(row_indices_);
begin = all_begin + (e.begin - all_begin);
}
CHECK_EQ(n_left + n_right, e.Size());
CHECK_LE(begin + n_left, e.end);

View File

@@ -266,6 +266,9 @@ class MemStackAllocator {
throw std::bad_alloc{};
}
}
/*!
 * \brief Construct the allocator and fill its buffer with an initial value.
 *
 * Delegates to the size-only constructor, then writes `init` into the first
 * `required_size_` elements starting at `ptr_`.
 *
 * \param required_size Requested size, forwarded to the size-only constructor.
 * \param init          Value copied into every filled element.
 */
MemStackAllocator(size_t required_size, T init) : MemStackAllocator{required_size} {
// NOTE(review): relies on the delegated constructor having set ptr_ and required_size_ — confirm.
std::fill_n(ptr_, required_size_, init);
}
~MemStackAllocator() {
if (required_size_ > MaxStackSize) {