Define core multi-target regression tree structure. (#8884)
- Define a new tree struct embedded in the `RegTree`. - Provide dispatching functions in `RegTree`. - Fix some c++-17 warnings about the use of nodiscard (currently we disable the warning on the CI). - Use uint32_t instead of size_t for `bst_target_t` as it has a defined size and can be used as part of dmlc parameter. - Hide the `Segment` struct inside the categorical split matrix.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <GPUTreeShap/gpu_treeshap.h>
|
||||
#include <thrust/copy.h>
|
||||
@@ -25,9 +25,7 @@
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
|
||||
namespace xgboost::predictor {
|
||||
DMLC_REGISTRY_FILE_TAG(gpu_predictor);
|
||||
|
||||
struct TreeView {
|
||||
@@ -35,12 +33,11 @@ struct TreeView {
|
||||
common::Span<RegTree::Node const> d_tree;
|
||||
|
||||
XGBOOST_DEVICE
|
||||
TreeView(size_t tree_begin, size_t tree_idx,
|
||||
common::Span<const RegTree::Node> d_nodes,
|
||||
TreeView(size_t tree_begin, size_t tree_idx, common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<size_t const> d_tree_segments,
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories) {
|
||||
auto begin = d_tree_segments[tree_idx - tree_begin];
|
||||
auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
@@ -255,7 +252,7 @@ PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories,
|
||||
|
||||
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||
@@ -290,7 +287,7 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<int const> d_tree_group,
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories, size_t tree_begin,
|
||||
size_t tree_end, size_t num_features, size_t num_rows,
|
||||
size_t entry_start, bool use_shared, int num_group, float missing) {
|
||||
@@ -334,7 +331,7 @@ class DeviceModel {
|
||||
// Pointer to each tree, segmenting the node array.
|
||||
HostDeviceVector<uint32_t> categories_tree_segments;
|
||||
// Pointer to each node, segmenting categories array.
|
||||
HostDeviceVector<RegTree::Segment> categories_node_segments;
|
||||
HostDeviceVector<RegTree::CategoricalSplitMatrix::Segment> categories_node_segments;
|
||||
HostDeviceVector<uint32_t> categories;
|
||||
|
||||
size_t tree_beg_; // NOLINT
|
||||
@@ -400,9 +397,9 @@ class DeviceModel {
|
||||
h_split_cat_segments.push_back(h_categories.size());
|
||||
}
|
||||
|
||||
categories_node_segments =
|
||||
HostDeviceVector<RegTree::Segment>(h_tree_segments.back(), {}, gpu_id);
|
||||
std::vector<RegTree::Segment> &h_categories_node_segments =
|
||||
categories_node_segments = HostDeviceVector<RegTree::CategoricalSplitMatrix::Segment>(
|
||||
h_tree_segments.back(), {}, gpu_id);
|
||||
std::vector<RegTree::CategoricalSplitMatrix::Segment>& h_categories_node_segments =
|
||||
categories_node_segments.HostVector();
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||
auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr();
|
||||
@@ -542,10 +539,10 @@ void ExtractPaths(
|
||||
if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types),
|
||||
common::IsCatOp{})) {
|
||||
dh::PinnedMemory pinned;
|
||||
auto h_max_cat = pinned.GetSpan<RegTree::Segment>(1);
|
||||
auto h_max_cat = pinned.GetSpan<RegTree::CategoricalSplitMatrix::Segment>(1);
|
||||
auto max_elem_it = dh::MakeTransformIterator<size_t>(
|
||||
dh::tbegin(d_cat_node_segments),
|
||||
[] __device__(RegTree::Segment seg) { return seg.size; });
|
||||
[] __device__(RegTree::CategoricalSplitMatrix::Segment seg) { return seg.size; });
|
||||
size_t max_cat_it =
|
||||
thrust::max_element(thrust::device, max_elem_it,
|
||||
max_elem_it + d_cat_node_segments.size()) -
|
||||
@@ -1028,5 +1025,4 @@ XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
||||
.describe("Make predictions using GPU.")
|
||||
.set_body([](Context const* ctx) { return new GPUPredictor(ctx); });
|
||||
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::predictor
|
||||
|
||||
Reference in New Issue
Block a user