Define core multi-target regression tree structure. (#8884)

- Define a new tree struct embedded in the `RegTree`.
- Provide dispatching functions in `RegTree`.
- Fix some c++-17 warnings about the use of nodiscard (currently we disable the warning on
  the CI).
- Use uint32_t instead of size_t for `bst_target_t` as it has a defined size and can be used
  as part of dmlc parameter.
- Hide the `Segment` struct inside the categorical split matrix.
This commit is contained in:
Jiaming Yuan
2023-03-09 19:03:06 +08:00
committed by GitHub
parent 46dfcc7d22
commit 5feee8d4a9
16 changed files with 809 additions and 264 deletions

View File

@@ -12,7 +12,7 @@
#include "../../common/hist_util.h"
#include "../../data/gradient_index.h"
#include "expand_entry.h"
#include "xgboost/tree_model.h"
#include "xgboost/tree_model.h" // for RegTree
namespace xgboost {
namespace tree {
@@ -175,8 +175,8 @@ class HistogramBuilder {
auto this_local = hist_local_worker_[entry.nid];
common::CopyHist(this_local, this_hist, r.begin(), r.end());
if (!(*p_tree)[entry.nid].IsRoot()) {
const size_t parent_id = (*p_tree)[entry.nid].Parent();
if (!p_tree->IsRoot(entry.nid)) {
const size_t parent_id = p_tree->Parent(entry.nid);
const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_local_worker_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
@@ -213,8 +213,8 @@ class HistogramBuilder {
// Merging histograms from each thread into once
this->buffer_.ReduceHist(node, r.begin(), r.end());
if (!(*p_tree)[entry.nid].IsRoot()) {
auto const parent_id = (*p_tree)[entry.nid].Parent();
if (!p_tree->IsRoot(entry.nid)) {
auto const parent_id = p_tree->Parent(entry.nid);
auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
@@ -237,10 +237,10 @@ class HistogramBuilder {
common::ParallelFor2d(
space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes[node];
if (!((*p_tree)[entry.nid].IsLeftChild())) {
if (!(p_tree->IsLeftChild(entry.nid))) {
auto this_hist = this->hist_[entry.nid];
if (!(*p_tree)[entry.nid].IsRoot()) {
if (!p_tree->IsRoot(entry.nid)) {
const int subtraction_node_id = subtraction_nodes[node].nid;
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
auto sibling_hist = hist_[subtraction_node_id];
@@ -285,7 +285,7 @@ class HistogramBuilder {
std::sort(merged_node_ids.begin(), merged_node_ids.end());
int n_left = 0;
for (auto const &nid : merged_node_ids) {
if ((*p_tree)[nid].IsLeftChild()) {
if (p_tree->IsLeftChild(nid)) {
this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
n_left++;
@@ -293,7 +293,7 @@ class HistogramBuilder {
}
}
for (auto const &nid : merged_node_ids) {
if (!((*p_tree)[nid].IsLeftChild())) {
if (!(p_tree->IsLeftChild(nid))) {
this->hist_.AddHistRow(nid);
this->hist_local_worker_.AddHistRow(nid);
}