Small cleanup to hist tree method. (#7735)
* Remove special optimization using number of bins. * Remove 1-based index for column sampling. * Remove data layout. * Unify update prediction cache.
This commit is contained in:
@@ -156,9 +156,8 @@ class ColumnSampler {
|
||||
* \param colsample_bytree
|
||||
* \param skip_index_0 (Optional) True to skip index 0.
|
||||
*/
|
||||
void Init(int64_t num_col, std::vector<float> feature_weights,
|
||||
float colsample_bynode, float colsample_bylevel,
|
||||
float colsample_bytree, bool skip_index_0 = false) {
|
||||
void Init(int64_t num_col, std::vector<float> feature_weights, float colsample_bynode,
|
||||
float colsample_bylevel, float colsample_bytree) {
|
||||
feature_weights_ = std::move(feature_weights);
|
||||
colsample_bylevel_ = colsample_bylevel;
|
||||
colsample_bytree_ = colsample_bytree;
|
||||
@@ -169,10 +168,8 @@ class ColumnSampler {
|
||||
}
|
||||
Reset();
|
||||
|
||||
int begin_idx = skip_index_0 ? 1 : 0;
|
||||
feature_set_tree_->Resize(num_col - begin_idx);
|
||||
std::iota(feature_set_tree_->HostVector().begin(),
|
||||
feature_set_tree_->HostVector().end(), begin_idx);
|
||||
feature_set_tree_->Resize(num_col);
|
||||
std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
|
||||
|
||||
feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
|
||||
}
|
||||
|
||||
@@ -55,8 +55,6 @@ class RowSetCollection {
|
||||
/*! \brief return corresponding element set given the node_id */
|
||||
inline const Elem& operator[](unsigned node_id) const {
|
||||
const Elem& e = elem_of_each_node_[node_id];
|
||||
CHECK(e.begin != nullptr)
|
||||
<< "access element that is not in the set";
|
||||
return e;
|
||||
}
|
||||
|
||||
@@ -75,14 +73,10 @@ class RowSetCollection {
|
||||
CHECK_EQ(elem_of_each_node_.size(), 0U);
|
||||
|
||||
if (row_indices_.empty()) { // edge case: empty instance set
|
||||
// assign arbitrary address here, to bypass nullptr check
|
||||
// (nullptr usually indicates a nonexistent rowset, but we want to
|
||||
// indicate a valid rowset that happens to have zero length and occupies
|
||||
// the whole instance set)
|
||||
// this is okay, as BuildHist will compute (end-begin) as the set size
|
||||
const size_t* begin = reinterpret_cast<size_t*>(20);
|
||||
const size_t* end = begin;
|
||||
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
|
||||
constexpr size_t* kBegin = nullptr;
|
||||
constexpr size_t* kEnd = nullptr;
|
||||
static_assert(kEnd - kBegin == 0, "");
|
||||
elem_of_each_node_.emplace_back(Elem(kBegin, kEnd, 0));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -93,15 +87,19 @@ class RowSetCollection {
|
||||
|
||||
std::vector<size_t>* Data() { return &row_indices_; }
|
||||
// split rowset into two
|
||||
inline void AddSplit(unsigned node_id,
|
||||
unsigned left_node_id,
|
||||
unsigned right_node_id,
|
||||
size_t n_left,
|
||||
size_t n_right) {
|
||||
inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
|
||||
size_t n_left, size_t n_right) {
|
||||
const Elem e = elem_of_each_node_[node_id];
|
||||
CHECK(e.begin != nullptr);
|
||||
size_t* all_begin = dmlc::BeginPtr(row_indices_);
|
||||
size_t* begin = all_begin + (e.begin - all_begin);
|
||||
|
||||
size_t* all_begin{nullptr};
|
||||
size_t* begin{nullptr};
|
||||
if (e.begin == nullptr) {
|
||||
CHECK_EQ(n_left, 0);
|
||||
CHECK_EQ(n_right, 0);
|
||||
} else {
|
||||
all_begin = dmlc::BeginPtr(row_indices_);
|
||||
begin = all_begin + (e.begin - all_begin);
|
||||
}
|
||||
|
||||
CHECK_EQ(n_left + n_right, e.Size());
|
||||
CHECK_LE(begin + n_left, e.end);
|
||||
|
||||
@@ -266,6 +266,9 @@ class MemStackAllocator {
|
||||
throw std::bad_alloc{};
|
||||
}
|
||||
}
|
||||
MemStackAllocator(size_t required_size, T init) : MemStackAllocator{required_size} {
|
||||
std::fill_n(ptr_, required_size_, init);
|
||||
}
|
||||
|
||||
~MemStackAllocator() {
|
||||
if (required_size_ > MaxStackSize) {
|
||||
|
||||
Reference in New Issue
Block a user