Small cleanup for rowset collection. (#10401)
This commit is contained in:
parent
e5f1720656
commit
2b400b18d5
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 by Contributors
|
* Copyright 2017-2024, XGBoost Contributors
|
||||||
* \file row_set.h
|
* \file row_set.h
|
||||||
* \brief Quick Utility to compute subset of rows
|
* \brief Quick Utility to compute subset of rows
|
||||||
* \author Philip Cho, Tianqi Chen
|
* \author Philip Cho, Tianqi Chen
|
||||||
@ -7,15 +7,17 @@
|
|||||||
#ifndef XGBOOST_COMMON_ROW_SET_H_
|
#ifndef XGBOOST_COMMON_ROW_SET_H_
|
||||||
#define XGBOOST_COMMON_ROW_SET_H_
|
#define XGBOOST_COMMON_ROW_SET_H_
|
||||||
|
|
||||||
#include <xgboost/data.h>
|
#include <cstddef> // for size_t
|
||||||
#include <algorithm>
|
#include <iterator> // for distance
|
||||||
#include <vector>
|
#include <vector> // for vector
|
||||||
#include <utility>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace xgboost {
|
#include "xgboost/base.h" // for bst_node_t
|
||||||
namespace common {
|
#include "xgboost/logging.h" // for CHECK
|
||||||
/*! \brief collection of rowset */
|
|
||||||
|
namespace xgboost::common {
|
||||||
|
/**
|
||||||
|
* @brief Collection of rows for each tree node.
|
||||||
|
*/
|
||||||
class RowSetCollection {
|
class RowSetCollection {
|
||||||
public:
|
public:
|
||||||
RowSetCollection() = default;
|
RowSetCollection() = default;
|
||||||
@ -24,84 +26,78 @@ class RowSetCollection {
|
|||||||
RowSetCollection& operator=(RowSetCollection const&) = delete;
|
RowSetCollection& operator=(RowSetCollection const&) = delete;
|
||||||
RowSetCollection& operator=(RowSetCollection&&) = default;
|
RowSetCollection& operator=(RowSetCollection&&) = default;
|
||||||
|
|
||||||
/*! \brief data structure to store an instance set, a subset of
|
/**
|
||||||
* rows (instances) associated with a particular node in a decision
|
* @brief data structure to store an instance set, a subset of rows (instances)
|
||||||
* tree. */
|
* associated with a particular node in a decision tree.
|
||||||
|
*/
|
||||||
struct Elem {
|
struct Elem {
|
||||||
const size_t* begin{nullptr};
|
std::size_t const* begin{nullptr};
|
||||||
const size_t* end{nullptr};
|
std::size_t const* end{nullptr};
|
||||||
bst_node_t node_id{-1};
|
bst_node_t node_id{-1};
|
||||||
// id of node associated with this instance set; -1 means uninitialized
|
// id of node associated with this instance set; -1 means uninitialized
|
||||||
Elem()
|
Elem() = default;
|
||||||
= default;
|
Elem(std::size_t const* begin, std::size_t const* end, bst_node_t node_id = -1)
|
||||||
Elem(const size_t* begin,
|
|
||||||
const size_t* end,
|
|
||||||
bst_node_t node_id = -1)
|
|
||||||
: begin(begin), end(end), node_id(node_id) {}
|
: begin(begin), end(end), node_id(node_id) {}
|
||||||
|
|
||||||
inline size_t Size() const {
|
std::size_t Size() const { return end - begin; }
|
||||||
return end - begin;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<Elem>::const_iterator begin() const { // NOLINT
|
[[nodiscard]] std::vector<Elem>::const_iterator begin() const { // NOLINT
|
||||||
return elem_of_each_node_.begin();
|
return elem_of_each_node_.cbegin();
|
||||||
|
}
|
||||||
|
[[nodiscard]] std::vector<Elem>::const_iterator end() const { // NOLINT
|
||||||
|
return elem_of_each_node_.cend();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Elem>::const_iterator end() const { // NOLINT
|
[[nodiscard]] std::size_t Size() const { return std::distance(begin(), end()); }
|
||||||
return elem_of_each_node_.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t Size() const { return std::distance(begin(), end()); }
|
/** @brief return corresponding element set given the node_id */
|
||||||
|
[[nodiscard]] Elem const& operator[](bst_node_t node_id) const {
|
||||||
/*! \brief return corresponding element set given the node_id */
|
Elem const& e = elem_of_each_node_[node_id];
|
||||||
inline const Elem& operator[](unsigned node_id) const {
|
|
||||||
const Elem& e = elem_of_each_node_[node_id];
|
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
/** @brief return corresponding element set given the node_id */
|
||||||
/*! \brief return corresponding element set given the node_id */
|
[[nodiscard]] Elem& operator[](bst_node_t node_id) {
|
||||||
inline Elem& operator[](unsigned node_id) {
|
|
||||||
Elem& e = elem_of_each_node_[node_id];
|
Elem& e = elem_of_each_node_[node_id];
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
|
||||||
// clear up things
|
// clear up things
|
||||||
inline void Clear() {
|
void Clear() {
|
||||||
elem_of_each_node_.clear();
|
elem_of_each_node_.clear();
|
||||||
}
|
}
|
||||||
// initialize node id 0->everything
|
// initialize node id 0->everything
|
||||||
inline void Init() {
|
void Init() {
|
||||||
CHECK_EQ(elem_of_each_node_.size(), 0U);
|
CHECK(elem_of_each_node_.empty());
|
||||||
|
|
||||||
if (row_indices_.empty()) { // edge case: empty instance set
|
if (row_indices_.empty()) { // edge case: empty instance set
|
||||||
constexpr size_t* kBegin = nullptr;
|
constexpr std::size_t* kBegin = nullptr;
|
||||||
constexpr size_t* kEnd = nullptr;
|
constexpr std::size_t* kEnd = nullptr;
|
||||||
static_assert(kEnd - kBegin == 0);
|
static_assert(kEnd - kBegin == 0);
|
||||||
elem_of_each_node_.emplace_back(kBegin, kEnd, 0);
|
elem_of_each_node_.emplace_back(kBegin, kEnd, 0);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t* begin = dmlc::BeginPtr(row_indices_);
|
const std::size_t* begin = dmlc::BeginPtr(row_indices_);
|
||||||
const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
|
const std::size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
|
||||||
elem_of_each_node_.emplace_back(begin, end, 0);
|
elem_of_each_node_.emplace_back(begin, end, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t>* Data() { return &row_indices_; }
|
[[nodiscard]] std::vector<std::size_t>* Data() { return &row_indices_; }
|
||||||
std::vector<size_t> const* Data() const { return &row_indices_; }
|
[[nodiscard]] std::vector<std::size_t> const* Data() const { return &row_indices_; }
|
||||||
|
|
||||||
// split rowset into two
|
// split rowset into two
|
||||||
inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
|
void AddSplit(bst_node_t node_id, bst_node_t left_node_id, bst_node_t right_node_id,
|
||||||
size_t n_left, size_t n_right) {
|
bst_idx_t n_left, bst_idx_t n_right) {
|
||||||
const Elem e = elem_of_each_node_[node_id];
|
const Elem e = elem_of_each_node_[node_id];
|
||||||
|
|
||||||
size_t* all_begin{nullptr};
|
std::size_t* all_begin{nullptr};
|
||||||
size_t* begin{nullptr};
|
std::size_t* begin{nullptr};
|
||||||
if (e.begin == nullptr) {
|
if (e.begin == nullptr) {
|
||||||
CHECK_EQ(n_left, 0);
|
CHECK_EQ(n_left, 0);
|
||||||
CHECK_EQ(n_right, 0);
|
CHECK_EQ(n_right, 0);
|
||||||
} else {
|
} else {
|
||||||
all_begin = dmlc::BeginPtr(row_indices_);
|
all_begin = row_indices_.data();
|
||||||
begin = all_begin + (e.begin - all_begin);
|
begin = all_begin + (e.begin - all_begin);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,25 +105,24 @@ class RowSetCollection {
|
|||||||
CHECK_LE(begin + n_left, e.end);
|
CHECK_LE(begin + n_left, e.end);
|
||||||
CHECK_EQ(begin + n_left + n_right, e.end);
|
CHECK_EQ(begin + n_left + n_right, e.end);
|
||||||
|
|
||||||
if (left_node_id >= elem_of_each_node_.size()) {
|
if (left_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) {
|
||||||
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
|
elem_of_each_node_.resize(left_node_id + 1, Elem{nullptr, nullptr, -1});
|
||||||
}
|
}
|
||||||
if (right_node_id >= elem_of_each_node_.size()) {
|
if (right_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) {
|
||||||
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
|
elem_of_each_node_.resize(right_node_id + 1, Elem{nullptr, nullptr, -1});
|
||||||
}
|
}
|
||||||
|
|
||||||
elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id);
|
elem_of_each_node_[left_node_id] = Elem{begin, begin + n_left, left_node_id};
|
||||||
elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id);
|
elem_of_each_node_[right_node_id] = Elem{begin + n_left, e.end, right_node_id};
|
||||||
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
|
elem_of_each_node_[node_id] = Elem{nullptr, nullptr, -1};
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// stores the row indexes in the set
|
// stores the row indexes in the set
|
||||||
std::vector<size_t> row_indices_;
|
std::vector<std::size_t> row_indices_;
|
||||||
// vector: node_id -> elements
|
// vector: node_id -> elements
|
||||||
std::vector<Elem> elem_of_each_node_;
|
std::vector<Elem> elem_of_each_node_;
|
||||||
};
|
};
|
||||||
} // namespace common
|
} // namespace xgboost::common
|
||||||
} // namespace xgboost
|
|
||||||
|
|
||||||
#endif // XGBOOST_COMMON_ROW_SET_H_
|
#endif // XGBOOST_COMMON_ROW_SET_H_
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user