Define core multi-target regression tree structure. (#8884)
- Define a new tree struct embedded in the `RegTree`. - Provide dispatching functions in `RegTree`. - Fix some c++-17 warnings about the use of nodiscard (currently we disable the warning on the CI). - Use uint32_t instead of size_t for `bst_target_t` as it has a defined size and can be used as part of dmlc parameter. - Hide the `Segment` struct inside the categorical split matrix.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
|
||||
@@ -179,7 +179,6 @@ template class HostDeviceVector<FeatureType>;
|
||||
template class HostDeviceVector<Entry>;
|
||||
template class HostDeviceVector<uint64_t>; // bst_row_t
|
||||
template class HostDeviceVector<uint32_t>; // bst_feature_t
|
||||
template class HostDeviceVector<RegTree::Segment>;
|
||||
|
||||
#if defined(__APPLE__) || defined(__EMSCRIPTEN__)
|
||||
/*
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost contributors
|
||||
*/
|
||||
|
||||
#include <thrust/fill.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
|
||||
@@ -412,7 +411,7 @@ template class HostDeviceVector<Entry>;
|
||||
template class HostDeviceVector<uint64_t>; // bst_row_t
|
||||
template class HostDeviceVector<uint32_t>; // bst_feature_t
|
||||
template class HostDeviceVector<RegTree::Node>;
|
||||
template class HostDeviceVector<RegTree::Segment>;
|
||||
template class HostDeviceVector<RegTree::CategoricalSplitMatrix::Segment>;
|
||||
template class HostDeviceVector<RTreeNodeStat>;
|
||||
|
||||
#if defined(__APPLE__)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <GPUTreeShap/gpu_treeshap.h>
|
||||
#include <thrust/copy.h>
|
||||
@@ -25,9 +25,7 @@
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
|
||||
namespace xgboost::predictor {
|
||||
DMLC_REGISTRY_FILE_TAG(gpu_predictor);
|
||||
|
||||
struct TreeView {
|
||||
@@ -35,12 +33,11 @@ struct TreeView {
|
||||
common::Span<RegTree::Node const> d_tree;
|
||||
|
||||
XGBOOST_DEVICE
|
||||
TreeView(size_t tree_begin, size_t tree_idx,
|
||||
common::Span<const RegTree::Node> d_nodes,
|
||||
TreeView(size_t tree_begin, size_t tree_idx, common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<size_t const> d_tree_segments,
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories) {
|
||||
auto begin = d_tree_segments[tree_idx - tree_begin];
|
||||
auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
@@ -255,7 +252,7 @@ PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories,
|
||||
|
||||
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||
@@ -290,7 +287,7 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<int const> d_tree_group,
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories, size_t tree_begin,
|
||||
size_t tree_end, size_t num_features, size_t num_rows,
|
||||
size_t entry_start, bool use_shared, int num_group, float missing) {
|
||||
@@ -334,7 +331,7 @@ class DeviceModel {
|
||||
// Pointer to each tree, segmenting the node array.
|
||||
HostDeviceVector<uint32_t> categories_tree_segments;
|
||||
// Pointer to each node, segmenting categories array.
|
||||
HostDeviceVector<RegTree::Segment> categories_node_segments;
|
||||
HostDeviceVector<RegTree::CategoricalSplitMatrix::Segment> categories_node_segments;
|
||||
HostDeviceVector<uint32_t> categories;
|
||||
|
||||
size_t tree_beg_; // NOLINT
|
||||
@@ -400,9 +397,9 @@ class DeviceModel {
|
||||
h_split_cat_segments.push_back(h_categories.size());
|
||||
}
|
||||
|
||||
categories_node_segments =
|
||||
HostDeviceVector<RegTree::Segment>(h_tree_segments.back(), {}, gpu_id);
|
||||
std::vector<RegTree::Segment> &h_categories_node_segments =
|
||||
categories_node_segments = HostDeviceVector<RegTree::CategoricalSplitMatrix::Segment>(
|
||||
h_tree_segments.back(), {}, gpu_id);
|
||||
std::vector<RegTree::CategoricalSplitMatrix::Segment>& h_categories_node_segments =
|
||||
categories_node_segments.HostVector();
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||
auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr();
|
||||
@@ -542,10 +539,10 @@ void ExtractPaths(
|
||||
if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types),
|
||||
common::IsCatOp{})) {
|
||||
dh::PinnedMemory pinned;
|
||||
auto h_max_cat = pinned.GetSpan<RegTree::Segment>(1);
|
||||
auto h_max_cat = pinned.GetSpan<RegTree::CategoricalSplitMatrix::Segment>(1);
|
||||
auto max_elem_it = dh::MakeTransformIterator<size_t>(
|
||||
dh::tbegin(d_cat_node_segments),
|
||||
[] __device__(RegTree::Segment seg) { return seg.size; });
|
||||
[] __device__(RegTree::CategoricalSplitMatrix::Segment seg) { return seg.size; });
|
||||
size_t max_cat_it =
|
||||
thrust::max_element(thrust::device, max_elem_it,
|
||||
max_elem_it + d_cat_node_segments.size()) -
|
||||
@@ -1028,5 +1025,4 @@ XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
||||
.describe("Make predictions using GPU.")
|
||||
.set_body([](Context const* ctx) { return new GPUPredictor(ctx); });
|
||||
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::predictor
|
||||
|
||||
@@ -71,10 +71,7 @@ void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
|
||||
auto n_samples = gpair.Size() / n_targets;
|
||||
|
||||
gpair.SetDevice(ctx->gpu_id);
|
||||
linalg::TensorView<GradientPair const, 2> gpair_t{
|
||||
ctx->IsCPU() ? gpair.ConstHostSpan() : gpair.ConstDeviceSpan(),
|
||||
{n_samples, n_targets},
|
||||
ctx->gpu_id};
|
||||
auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
|
||||
ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
|
||||
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "../../common/hist_util.h"
|
||||
#include "../../data/gradient_index.h"
|
||||
#include "expand_entry.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_model.h" // for RegTree
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -175,8 +175,8 @@ class HistogramBuilder {
|
||||
auto this_local = hist_local_worker_[entry.nid];
|
||||
common::CopyHist(this_local, this_hist, r.begin(), r.end());
|
||||
|
||||
if (!(*p_tree)[entry.nid].IsRoot()) {
|
||||
const size_t parent_id = (*p_tree)[entry.nid].Parent();
|
||||
if (!p_tree->IsRoot(entry.nid)) {
|
||||
const size_t parent_id = p_tree->Parent(entry.nid);
|
||||
const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
|
||||
auto parent_hist = this->hist_local_worker_[parent_id];
|
||||
auto sibling_hist = this->hist_[subtraction_node_id];
|
||||
@@ -213,8 +213,8 @@ class HistogramBuilder {
|
||||
// Merging histograms from each thread into once
|
||||
this->buffer_.ReduceHist(node, r.begin(), r.end());
|
||||
|
||||
if (!(*p_tree)[entry.nid].IsRoot()) {
|
||||
auto const parent_id = (*p_tree)[entry.nid].Parent();
|
||||
if (!p_tree->IsRoot(entry.nid)) {
|
||||
auto const parent_id = p_tree->Parent(entry.nid);
|
||||
auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid;
|
||||
auto parent_hist = this->hist_[parent_id];
|
||||
auto sibling_hist = this->hist_[subtraction_node_id];
|
||||
@@ -237,10 +237,10 @@ class HistogramBuilder {
|
||||
common::ParallelFor2d(
|
||||
space, this->n_threads_, [&](size_t node, common::Range1d r) {
|
||||
const auto &entry = nodes[node];
|
||||
if (!((*p_tree)[entry.nid].IsLeftChild())) {
|
||||
if (!(p_tree->IsLeftChild(entry.nid))) {
|
||||
auto this_hist = this->hist_[entry.nid];
|
||||
|
||||
if (!(*p_tree)[entry.nid].IsRoot()) {
|
||||
if (!p_tree->IsRoot(entry.nid)) {
|
||||
const int subtraction_node_id = subtraction_nodes[node].nid;
|
||||
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
|
||||
auto sibling_hist = hist_[subtraction_node_id];
|
||||
@@ -285,7 +285,7 @@ class HistogramBuilder {
|
||||
std::sort(merged_node_ids.begin(), merged_node_ids.end());
|
||||
int n_left = 0;
|
||||
for (auto const &nid : merged_node_ids) {
|
||||
if ((*p_tree)[nid].IsLeftChild()) {
|
||||
if (p_tree->IsLeftChild(nid)) {
|
||||
this->hist_.AddHistRow(nid);
|
||||
(*starting_index) = std::min(nid, (*starting_index));
|
||||
n_left++;
|
||||
@@ -293,7 +293,7 @@ class HistogramBuilder {
|
||||
}
|
||||
}
|
||||
for (auto const &nid : merged_node_ids) {
|
||||
if (!((*p_tree)[nid].IsLeftChild())) {
|
||||
if (!(p_tree->IsLeftChild(nid))) {
|
||||
this->hist_.AddHistRow(nid);
|
||||
this->hist_local_worker_.AddHistRow(nid);
|
||||
}
|
||||
|
||||
65
src/tree/io_utils.h
Normal file
65
src/tree/io_utils.h
Normal file
@@ -0,0 +1,65 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_IO_UTILS_H_
|
||||
#define XGBOOST_TREE_IO_UTILS_H_
|
||||
#include <string> // for string
|
||||
#include <type_traits> // for enable_if_t, is_same, conditional_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost {
|
||||
template <bool typed>
|
||||
using FloatArrayT = std::conditional_t<typed, F32Array const, Array const>;
|
||||
template <bool typed>
|
||||
using U8ArrayT = std::conditional_t<typed, U8Array const, Array const>;
|
||||
template <bool typed>
|
||||
using I32ArrayT = std::conditional_t<typed, I32Array const, Array const>;
|
||||
template <bool typed>
|
||||
using I64ArrayT = std::conditional_t<typed, I64Array const, Array const>;
|
||||
template <bool typed, bool feature_is_64>
|
||||
using IndexArrayT = std::conditional_t<feature_is_64, I64ArrayT<typed>, I32ArrayT<typed>>;
|
||||
|
||||
// typed array, not boolean
|
||||
template <typename JT, typename T>
|
||||
std::enable_if_t<!std::is_same<T, Json>::value && !std::is_same<JT, Boolean>::value, T> GetElem(
|
||||
std::vector<T> const& arr, size_t i) {
|
||||
return arr[i];
|
||||
}
|
||||
// typed array boolean
|
||||
template <typename JT, typename T>
|
||||
std::enable_if_t<!std::is_same<T, Json>::value && std::is_same<T, uint8_t>::value &&
|
||||
std::is_same<JT, Boolean>::value,
|
||||
bool>
|
||||
GetElem(std::vector<T> const& arr, size_t i) {
|
||||
return arr[i] == 1;
|
||||
}
|
||||
// json array
|
||||
template <typename JT, typename T>
|
||||
std::enable_if_t<
|
||||
std::is_same<T, Json>::value,
|
||||
std::conditional_t<std::is_same<JT, Integer>::value, int64_t,
|
||||
std::conditional_t<std::is_same<Boolean, JT>::value, bool, float>>>
|
||||
GetElem(std::vector<T> const& arr, size_t i) {
|
||||
if (std::is_same<JT, Boolean>::value && !IsA<Boolean>(arr[i])) {
|
||||
return get<Integer const>(arr[i]) == 1;
|
||||
}
|
||||
return get<JT const>(arr[i]);
|
||||
}
|
||||
|
||||
namespace tree_field {
|
||||
inline std::string const kLossChg{"loss_changes"};
|
||||
inline std::string const kSumHess{"sum_hessian"};
|
||||
inline std::string const kBaseWeight{"base_weights"};
|
||||
|
||||
inline std::string const kSplitIdx{"split_indices"};
|
||||
inline std::string const kSplitCond{"split_conditions"};
|
||||
inline std::string const kDftLeft{"default_left"};
|
||||
|
||||
inline std::string const kParent{"parents"};
|
||||
inline std::string const kLeft{"left_children"};
|
||||
inline std::string const kRight{"right_children"};
|
||||
} // namespace tree_field
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_IO_UTILS_H_
|
||||
220
src/tree/multi_target_tree_model.cc
Normal file
220
src/tree/multi_target_tree_model.cc
Normal file
@@ -0,0 +1,220 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost Contributors
|
||||
*/
|
||||
#include "xgboost/multi_target_tree_model.h"
|
||||
|
||||
#include <algorithm> // for copy_n
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, uint8_t
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string_view> // for string_view
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
|
||||
#include "xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
|
||||
#include "xgboost/json.h" // for Json, get, Object, Number, Integer, ...
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/tree_model.h" // for TreeParam
|
||||
|
||||
namespace xgboost {
|
||||
MultiTargetTree::MultiTargetTree(TreeParam const* param)
|
||||
: param_{param},
|
||||
left_(1ul, InvalidNodeId()),
|
||||
right_(1ul, InvalidNodeId()),
|
||||
parent_(1ul, InvalidNodeId()),
|
||||
split_index_(1ul, 0),
|
||||
default_left_(1ul, 0),
|
||||
split_conds_(1ul, std::numeric_limits<float>::quiet_NaN()),
|
||||
weights_(param->size_leaf_vector, std::numeric_limits<float>::quiet_NaN()) {
|
||||
CHECK_GT(param_->size_leaf_vector, 1);
|
||||
}
|
||||
|
||||
template <bool typed, bool feature_is_64>
|
||||
void LoadModelImpl(Json const& in, std::vector<float>* p_weights, std::vector<bst_node_t>* p_lefts,
|
||||
std::vector<bst_node_t>* p_rights, std::vector<bst_node_t>* p_parents,
|
||||
std::vector<float>* p_conds, std::vector<bst_feature_t>* p_fidx,
|
||||
std::vector<std::uint8_t>* p_dft_left) {
|
||||
namespace tf = tree_field;
|
||||
|
||||
auto get_float = [&](std::string_view name, std::vector<float>* p_out) {
|
||||
auto& values = get<FloatArrayT<typed>>(get<Object const>(in).find(name)->second);
|
||||
auto& out = *p_out;
|
||||
out.resize(values.size());
|
||||
for (std::size_t i = 0; i < values.size(); ++i) {
|
||||
out[i] = GetElem<Number>(values, i);
|
||||
}
|
||||
};
|
||||
get_float(tf::kBaseWeight, p_weights);
|
||||
get_float(tf::kSplitCond, p_conds);
|
||||
|
||||
auto get_nidx = [&](std::string_view name, std::vector<bst_node_t>* p_nidx) {
|
||||
auto& nidx = get<I32ArrayT<typed>>(get<Object const>(in).find(name)->second);
|
||||
auto& out_nidx = *p_nidx;
|
||||
out_nidx.resize(nidx.size());
|
||||
for (std::size_t i = 0; i < nidx.size(); ++i) {
|
||||
out_nidx[i] = GetElem<Integer>(nidx, i);
|
||||
}
|
||||
};
|
||||
get_nidx(tf::kLeft, p_lefts);
|
||||
get_nidx(tf::kRight, p_rights);
|
||||
get_nidx(tf::kParent, p_parents);
|
||||
|
||||
auto const& splits = get<IndexArrayT<typed, feature_is_64> const>(in[tf::kSplitIdx]);
|
||||
p_fidx->resize(splits.size());
|
||||
auto& out_fidx = *p_fidx;
|
||||
for (std::size_t i = 0; i < splits.size(); ++i) {
|
||||
out_fidx[i] = GetElem<Integer>(splits, i);
|
||||
}
|
||||
|
||||
auto const& dft_left = get<U8ArrayT<typed> const>(in[tf::kDftLeft]);
|
||||
auto& out_dft_l = *p_dft_left;
|
||||
out_dft_l.resize(dft_left.size());
|
||||
for (std::size_t i = 0; i < dft_left.size(); ++i) {
|
||||
out_dft_l[i] = GetElem<Boolean>(dft_left, i);
|
||||
}
|
||||
}
|
||||
|
||||
void MultiTargetTree::LoadModel(Json const& in) {
|
||||
namespace tf = tree_field;
|
||||
bool typed = IsA<F32Array>(in[tf::kBaseWeight]);
|
||||
bool feature_is_64 = IsA<I64Array>(in[tf::kSplitIdx]);
|
||||
|
||||
if (typed && feature_is_64) {
|
||||
LoadModelImpl<true, true>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
|
||||
&split_index_, &default_left_);
|
||||
} else if (typed && !feature_is_64) {
|
||||
LoadModelImpl<true, false>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
|
||||
&split_index_, &default_left_);
|
||||
} else if (!typed && feature_is_64) {
|
||||
LoadModelImpl<false, true>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
|
||||
&split_index_, &default_left_);
|
||||
} else {
|
||||
LoadModelImpl<false, false>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
|
||||
&split_index_, &default_left_);
|
||||
}
|
||||
}
|
||||
|
||||
void MultiTargetTree::SaveModel(Json* p_out) const {
|
||||
CHECK(p_out);
|
||||
auto& out = *p_out;
|
||||
|
||||
auto n_nodes = param_->num_nodes;
|
||||
|
||||
// nodes
|
||||
I32Array lefts(n_nodes);
|
||||
I32Array rights(n_nodes);
|
||||
I32Array parents(n_nodes);
|
||||
F32Array conds(n_nodes);
|
||||
U8Array default_left(n_nodes);
|
||||
F32Array weights(n_nodes * this->NumTarget());
|
||||
|
||||
auto save_tree = [&](auto* p_indices_array) {
|
||||
auto& indices_array = *p_indices_array;
|
||||
for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) {
|
||||
CHECK_LT(nidx, left_.size());
|
||||
lefts.Set(nidx, left_[nidx]);
|
||||
CHECK_LT(nidx, right_.size());
|
||||
rights.Set(nidx, right_[nidx]);
|
||||
CHECK_LT(nidx, parent_.size());
|
||||
parents.Set(nidx, parent_[nidx]);
|
||||
CHECK_LT(nidx, split_index_.size());
|
||||
indices_array.Set(nidx, split_index_[nidx]);
|
||||
conds.Set(nidx, split_conds_[nidx]);
|
||||
default_left.Set(nidx, default_left_[nidx]);
|
||||
|
||||
auto in_weight = this->NodeWeight(nidx);
|
||||
auto weight_out = common::Span<float>(weights.GetArray())
|
||||
.subspan(nidx * this->NumTarget(), this->NumTarget());
|
||||
CHECK_EQ(in_weight.Size(), weight_out.size());
|
||||
std::copy_n(in_weight.Values().data(), in_weight.Size(), weight_out.data());
|
||||
}
|
||||
};
|
||||
|
||||
namespace tf = tree_field;
|
||||
|
||||
if (this->param_->num_feature >
|
||||
static_cast<bst_feature_t>(std::numeric_limits<std::int32_t>::max())) {
|
||||
I64Array indices_64(n_nodes);
|
||||
save_tree(&indices_64);
|
||||
out[tf::kSplitIdx] = std::move(indices_64);
|
||||
} else {
|
||||
I32Array indices_32(n_nodes);
|
||||
save_tree(&indices_32);
|
||||
out[tf::kSplitIdx] = std::move(indices_32);
|
||||
}
|
||||
|
||||
out[tf::kBaseWeight] = std::move(weights);
|
||||
out[tf::kLeft] = std::move(lefts);
|
||||
out[tf::kRight] = std::move(rights);
|
||||
out[tf::kParent] = std::move(parents);
|
||||
|
||||
out[tf::kSplitCond] = std::move(conds);
|
||||
out[tf::kDftLeft] = std::move(default_left);
|
||||
}
|
||||
|
||||
void MultiTargetTree::SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
|
||||
CHECK(this->IsLeaf(nidx)) << "Collapsing a split node to leaf " << MTNotImplemented();
|
||||
auto const next_nidx = nidx + 1;
|
||||
CHECK_EQ(weight.Size(), this->NumTarget());
|
||||
CHECK_GE(weights_.size(), next_nidx * weight.Size());
|
||||
auto out_weight = common::Span<float>(weights_).subspan(nidx * weight.Size(), weight.Size());
|
||||
for (std::size_t i = 0; i < weight.Size(); ++i) {
|
||||
out_weight[i] = weight(i);
|
||||
}
|
||||
}
|
||||
|
||||
void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond,
|
||||
bool default_left, linalg::VectorView<float const> base_weight,
|
||||
linalg::VectorView<float const> left_weight,
|
||||
linalg::VectorView<float const> right_weight) {
|
||||
CHECK(this->IsLeaf(nidx));
|
||||
CHECK_GE(parent_.size(), 1);
|
||||
CHECK_EQ(parent_.size(), left_.size());
|
||||
CHECK_EQ(left_.size(), right_.size());
|
||||
|
||||
std::size_t n = param_->num_nodes + 2;
|
||||
CHECK_LT(split_idx, this->param_->num_feature);
|
||||
left_.resize(n, InvalidNodeId());
|
||||
right_.resize(n, InvalidNodeId());
|
||||
parent_.resize(n, InvalidNodeId());
|
||||
|
||||
auto left_child = parent_.size() - 2;
|
||||
auto right_child = parent_.size() - 1;
|
||||
|
||||
left_[nidx] = left_child;
|
||||
right_[nidx] = right_child;
|
||||
|
||||
if (nidx != 0) {
|
||||
CHECK_NE(parent_[nidx], InvalidNodeId());
|
||||
}
|
||||
|
||||
parent_[left_child] = nidx;
|
||||
parent_[right_child] = nidx;
|
||||
|
||||
split_index_.resize(n);
|
||||
split_index_[nidx] = split_idx;
|
||||
|
||||
split_conds_.resize(n);
|
||||
split_conds_[nidx] = split_cond;
|
||||
default_left_.resize(n);
|
||||
default_left_[nidx] = static_cast<std::uint8_t>(default_left);
|
||||
|
||||
weights_.resize(n * this->NumTarget());
|
||||
auto p_weight = this->NodeWeight(nidx);
|
||||
CHECK_EQ(p_weight.Size(), base_weight.Size());
|
||||
auto l_weight = this->NodeWeight(left_child);
|
||||
CHECK_EQ(l_weight.Size(), left_weight.Size());
|
||||
auto r_weight = this->NodeWeight(right_child);
|
||||
CHECK_EQ(r_weight.Size(), right_weight.Size());
|
||||
|
||||
for (std::size_t i = 0; i < base_weight.Size(); ++i) {
|
||||
p_weight(i) = base_weight(i);
|
||||
l_weight(i) = left_weight(i);
|
||||
r_weight(i) = right_weight(i);
|
||||
}
|
||||
}
|
||||
|
||||
bst_target_t MultiTargetTree::NumTarget() const { return param_->size_leaf_vector; }
|
||||
std::size_t MultiTargetTree::Size() const { return parent_.size(); }
|
||||
} // namespace xgboost
|
||||
@@ -1,25 +1,27 @@
|
||||
/*!
|
||||
* Copyright 2015-2022 by Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by Contributors
|
||||
* \file tree_model.cc
|
||||
* \brief model structure for tree
|
||||
*/
|
||||
#include <dmlc/registry.h>
|
||||
#include <dmlc/json.h>
|
||||
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <stack>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
|
||||
#include "param.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/common.h"
|
||||
#include "../predictor/predict_fn.h"
|
||||
#include "io_utils.h" // GetElem
|
||||
#include "param.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost {
|
||||
// register tree parameter
|
||||
@@ -729,12 +731,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
|
||||
|
||||
constexpr bst_node_t RegTree::kRoot;
|
||||
|
||||
std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const {
|
||||
std::unique_ptr<TreeGenerator> builder {
|
||||
TreeGenerator::Create(format, fmap, with_stats)
|
||||
};
|
||||
std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const {
|
||||
CHECK(!IsMultiTarget());
|
||||
std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)};
|
||||
builder->BuildTree(*this);
|
||||
|
||||
std::string result = builder->Str();
|
||||
@@ -742,6 +741,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
}
|
||||
|
||||
bool RegTree::Equal(const RegTree& b) const {
|
||||
CHECK(!IsMultiTarget());
|
||||
if (NumExtraNodes() != b.NumExtraNodes()) {
|
||||
return false;
|
||||
}
|
||||
@@ -758,6 +758,7 @@ bool RegTree::Equal(const RegTree& b) const {
|
||||
}
|
||||
|
||||
bst_node_t RegTree::GetNumLeaves() const {
|
||||
CHECK(!IsMultiTarget());
|
||||
bst_node_t leaves { 0 };
|
||||
auto const& self = *this;
|
||||
this->WalkTree([&leaves, &self](bst_node_t nidx) {
|
||||
@@ -770,6 +771,7 @@ bst_node_t RegTree::GetNumLeaves() const {
|
||||
}
|
||||
|
||||
bst_node_t RegTree::GetNumSplitNodes() const {
|
||||
CHECK(!IsMultiTarget());
|
||||
bst_node_t splits { 0 };
|
||||
auto const& self = *this;
|
||||
this->WalkTree([&splits, &self](bst_node_t nidx) {
|
||||
@@ -787,6 +789,7 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
|
||||
bst_float right_leaf_weight, bst_float loss_change,
|
||||
float sum_hess, float left_sum, float right_sum,
|
||||
bst_node_t leaf_right_child) {
|
||||
CHECK(!IsMultiTarget());
|
||||
int pleft = this->AllocNode();
|
||||
int pright = this->AllocNode();
|
||||
auto &node = nodes_[nid];
|
||||
@@ -807,11 +810,31 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
|
||||
this->split_types_.at(nid) = FeatureType::kNumerical;
|
||||
}
|
||||
|
||||
void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond,
|
||||
bool default_left, linalg::VectorView<float const> base_weight,
|
||||
linalg::VectorView<float const> left_weight,
|
||||
linalg::VectorView<float const> right_weight) {
|
||||
CHECK(IsMultiTarget());
|
||||
CHECK_LT(split_index, this->param.num_feature);
|
||||
CHECK(this->p_mt_tree_);
|
||||
CHECK_GT(param.size_leaf_vector, 1);
|
||||
|
||||
this->p_mt_tree_->Expand(nidx, split_index, split_cond, default_left, base_weight, left_weight,
|
||||
right_weight);
|
||||
|
||||
split_types_.resize(this->Size(), FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(this->Size());
|
||||
this->split_types_.at(nidx) = FeatureType::kNumerical;
|
||||
|
||||
this->param.num_nodes = this->p_mt_tree_->Size();
|
||||
}
|
||||
|
||||
void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
|
||||
common::Span<const uint32_t> split_cat, bool default_left,
|
||||
bst_float base_weight, bst_float left_leaf_weight,
|
||||
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
|
||||
float left_sum, float right_sum) {
|
||||
CHECK(!IsMultiTarget());
|
||||
this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
|
||||
default_left, base_weight,
|
||||
left_leaf_weight, right_leaf_weight, loss_change, sum_hess,
|
||||
@@ -893,44 +916,17 @@ void RegTree::Save(dmlc::Stream* fo) const {
|
||||
}
|
||||
}
|
||||
}
|
||||
// typed array, not boolean
|
||||
template <typename JT, typename T>
|
||||
std::enable_if_t<!std::is_same<T, Json>::value && !std::is_same<JT, Boolean>::value, T> GetElem(
|
||||
std::vector<T> const& arr, size_t i) {
|
||||
return arr[i];
|
||||
}
|
||||
// typed array boolean
|
||||
template <typename JT, typename T>
|
||||
std::enable_if_t<!std::is_same<T, Json>::value && std::is_same<T, uint8_t>::value &&
|
||||
std::is_same<JT, Boolean>::value,
|
||||
bool>
|
||||
GetElem(std::vector<T> const& arr, size_t i) {
|
||||
return arr[i] == 1;
|
||||
}
|
||||
// json array
|
||||
template <typename JT, typename T>
|
||||
std::enable_if_t<
|
||||
std::is_same<T, Json>::value,
|
||||
std::conditional_t<std::is_same<JT, Integer>::value, int64_t,
|
||||
std::conditional_t<std::is_same<Boolean, JT>::value, bool, float>>>
|
||||
GetElem(std::vector<T> const& arr, size_t i) {
|
||||
if (std::is_same<JT, Boolean>::value && !IsA<Boolean>(arr[i])) {
|
||||
return get<Integer const>(arr[i]) == 1;
|
||||
}
|
||||
return get<JT const>(arr[i]);
|
||||
}
|
||||
|
||||
template <bool typed>
|
||||
void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
using I64ArrayT = std::conditional_t<typed, I64Array const, Array const>;
|
||||
using I32ArrayT = std::conditional_t<typed, I32Array const, Array const>;
|
||||
auto const& categories_segments = get<I64ArrayT<typed>>(in["categories_segments"]);
|
||||
auto const& categories_sizes = get<I64ArrayT<typed>>(in["categories_sizes"]);
|
||||
auto const& categories_nodes = get<I32ArrayT<typed>>(in["categories_nodes"]);
|
||||
auto const& categories = get<I32ArrayT<typed>>(in["categories"]);
|
||||
|
||||
auto const& categories_segments = get<I64ArrayT>(in["categories_segments"]);
|
||||
auto const& categories_sizes = get<I64ArrayT>(in["categories_sizes"]);
|
||||
auto const& categories_nodes = get<I32ArrayT>(in["categories_nodes"]);
|
||||
auto const& categories = get<I32ArrayT>(in["categories"]);
|
||||
|
||||
size_t cnt = 0;
|
||||
auto split_type = get<U8ArrayT<typed>>(in["split_type"]);
|
||||
bst_node_t n_nodes = split_type.size();
|
||||
std::size_t cnt = 0;
|
||||
bst_node_t last_cat_node = -1;
|
||||
if (!categories_nodes.empty()) {
|
||||
last_cat_node = GetElem<Integer>(categories_nodes, cnt);
|
||||
@@ -938,7 +934,10 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
// `categories_segments' is only available for categorical nodes to prevent overhead for
|
||||
// numerical node. As a result, we need to track the categorical nodes we have processed
|
||||
// so far.
|
||||
for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) {
|
||||
split_types_.resize(n_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(n_nodes);
|
||||
for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) {
|
||||
split_types_[nidx] = static_cast<FeatureType>(GetElem<Integer>(split_type, nidx));
|
||||
if (nidx == last_cat_node) {
|
||||
auto j_begin = GetElem<Integer>(categories_segments, cnt);
|
||||
auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin;
|
||||
@@ -985,15 +984,17 @@ template void RegTree::LoadCategoricalSplit<false>(Json const& in);
|
||||
|
||||
void RegTree::SaveCategoricalSplit(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
CHECK_EQ(this->split_types_.size(), param.num_nodes);
|
||||
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
|
||||
CHECK_EQ(this->split_types_.size(), this->Size());
|
||||
CHECK_EQ(this->GetSplitCategoriesPtr().size(), this->Size());
|
||||
|
||||
I64Array categories_segments;
|
||||
I64Array categories_sizes;
|
||||
I32Array categories; // bst_cat_t = int32_t
|
||||
I32Array categories_nodes; // bst_note_t = int32_t
|
||||
U8Array split_type(split_types_.size());
|
||||
|
||||
for (size_t i = 0; i < nodes_.size(); ++i) {
|
||||
split_type.Set(i, static_cast<std::underlying_type_t<FeatureType>>(this->NodeSplitType(i)));
|
||||
if (this->split_types_[i] == FeatureType::kCategorical) {
|
||||
categories_nodes.GetArray().emplace_back(i);
|
||||
auto begin = categories.Size();
|
||||
@@ -1012,66 +1013,49 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
|
||||
}
|
||||
}
|
||||
|
||||
out["split_type"] = std::move(split_type);
|
||||
out["categories_segments"] = std::move(categories_segments);
|
||||
out["categories_sizes"] = std::move(categories_sizes);
|
||||
out["categories_nodes"] = std::move(categories_nodes);
|
||||
out["categories"] = std::move(categories);
|
||||
}
|
||||
|
||||
template <bool typed, bool feature_is_64,
|
||||
typename FloatArrayT = std::conditional_t<typed, F32Array const, Array const>,
|
||||
typename U8ArrayT = std::conditional_t<typed, U8Array const, Array const>,
|
||||
typename I32ArrayT = std::conditional_t<typed, I32Array const, Array const>,
|
||||
typename I64ArrayT = std::conditional_t<typed, I64Array const, Array const>,
|
||||
typename IndexArrayT = std::conditional_t<feature_is_64, I64ArrayT, I32ArrayT>>
|
||||
bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>* p_stats,
|
||||
std::vector<FeatureType>* p_split_types, std::vector<RegTree::Node>* p_nodes,
|
||||
std::vector<RegTree::Segment>* p_split_categories_segments) {
|
||||
template <bool typed, bool feature_is_64>
|
||||
void LoadModelImpl(Json const& in, TreeParam const& param, std::vector<RTreeNodeStat>* p_stats,
|
||||
std::vector<RegTree::Node>* p_nodes) {
|
||||
namespace tf = tree_field;
|
||||
auto& stats = *p_stats;
|
||||
auto& split_types = *p_split_types;
|
||||
auto& nodes = *p_nodes;
|
||||
auto& split_categories_segments = *p_split_categories_segments;
|
||||
|
||||
FromJson(in["tree_param"], param);
|
||||
auto n_nodes = param->num_nodes;
|
||||
auto n_nodes = param.num_nodes;
|
||||
CHECK_NE(n_nodes, 0);
|
||||
// stats
|
||||
auto const& loss_changes = get<FloatArrayT>(in["loss_changes"]);
|
||||
auto const& loss_changes = get<FloatArrayT<typed>>(in[tf::kLossChg]);
|
||||
CHECK_EQ(loss_changes.size(), n_nodes);
|
||||
auto const& sum_hessian = get<FloatArrayT>(in["sum_hessian"]);
|
||||
auto const& sum_hessian = get<FloatArrayT<typed>>(in[tf::kSumHess]);
|
||||
CHECK_EQ(sum_hessian.size(), n_nodes);
|
||||
auto const& base_weights = get<FloatArrayT>(in["base_weights"]);
|
||||
auto const& base_weights = get<FloatArrayT<typed>>(in[tf::kBaseWeight]);
|
||||
CHECK_EQ(base_weights.size(), n_nodes);
|
||||
// nodes
|
||||
auto const& lefts = get<I32ArrayT>(in["left_children"]);
|
||||
auto const& lefts = get<I32ArrayT<typed>>(in[tf::kLeft]);
|
||||
CHECK_EQ(lefts.size(), n_nodes);
|
||||
auto const& rights = get<I32ArrayT>(in["right_children"]);
|
||||
auto const& rights = get<I32ArrayT<typed>>(in[tf::kRight]);
|
||||
CHECK_EQ(rights.size(), n_nodes);
|
||||
auto const& parents = get<I32ArrayT>(in["parents"]);
|
||||
auto const& parents = get<I32ArrayT<typed>>(in[tf::kParent]);
|
||||
CHECK_EQ(parents.size(), n_nodes);
|
||||
auto const& indices = get<IndexArrayT>(in["split_indices"]);
|
||||
auto const& indices = get<IndexArrayT<typed, feature_is_64>>(in[tf::kSplitIdx]);
|
||||
CHECK_EQ(indices.size(), n_nodes);
|
||||
auto const& conds = get<FloatArrayT>(in["split_conditions"]);
|
||||
auto const& conds = get<FloatArrayT<typed>>(in[tf::kSplitCond]);
|
||||
CHECK_EQ(conds.size(), n_nodes);
|
||||
auto const& default_left = get<U8ArrayT>(in["default_left"]);
|
||||
auto const& default_left = get<U8ArrayT<typed>>(in[tf::kDftLeft]);
|
||||
CHECK_EQ(default_left.size(), n_nodes);
|
||||
|
||||
bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
|
||||
std::remove_const_t<std::remove_reference_t<decltype(get<U8ArrayT const>(in["split_type"]))>>
|
||||
split_type;
|
||||
if (has_cat) {
|
||||
split_type = get<U8ArrayT const>(in["split_type"]);
|
||||
}
|
||||
|
||||
// Initialization
|
||||
stats = std::remove_reference_t<decltype(stats)>(n_nodes);
|
||||
nodes = std::remove_reference_t<decltype(nodes)>(n_nodes);
|
||||
split_types = std::remove_reference_t<decltype(split_types)>(n_nodes);
|
||||
split_categories_segments = std::remove_reference_t<decltype(split_categories_segments)>(n_nodes);
|
||||
|
||||
static_assert(std::is_integral<decltype(GetElem<Integer>(lefts, 0))>::value);
|
||||
static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value);
|
||||
CHECK_EQ(n_nodes, split_categories_segments.size());
|
||||
|
||||
// Set node
|
||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||
@@ -1088,41 +1072,46 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
|
||||
float cond{GetElem<Number>(conds, i)};
|
||||
bool dft_left{GetElem<Boolean>(default_left, i)};
|
||||
n = RegTree::Node{left, right, parent, ind, cond, dft_left};
|
||||
|
||||
if (has_cat) {
|
||||
split_types[i] = static_cast<FeatureType>(GetElem<Integer>(split_type, i));
|
||||
}
|
||||
}
|
||||
|
||||
return has_cat;
|
||||
}
|
||||
|
||||
void RegTree::LoadModel(Json const& in) {
|
||||
bool has_cat{false};
|
||||
bool typed = IsA<F32Array>(in["loss_changes"]);
|
||||
bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
|
||||
if (typed && feature_is_64) {
|
||||
has_cat = LoadModelImpl<true, true>(in, ¶m, &stats_, &split_types_, &nodes_,
|
||||
&split_categories_segments_);
|
||||
} else if (typed && !feature_is_64) {
|
||||
has_cat = LoadModelImpl<true, false>(in, ¶m, &stats_, &split_types_, &nodes_,
|
||||
&split_categories_segments_);
|
||||
} else if (!typed && feature_is_64) {
|
||||
has_cat = LoadModelImpl<false, true>(in, ¶m, &stats_, &split_types_, &nodes_,
|
||||
&split_categories_segments_);
|
||||
} else {
|
||||
has_cat = LoadModelImpl<false, false>(in, ¶m, &stats_, &split_types_, &nodes_,
|
||||
&split_categories_segments_);
|
||||
}
|
||||
namespace tf = tree_field;
|
||||
|
||||
bool typed = IsA<I32Array>(in[tf::kParent]);
|
||||
auto const& in_obj = get<Object const>(in);
|
||||
// basic properties
|
||||
FromJson(in["tree_param"], ¶m);
|
||||
// categorical splits
|
||||
bool has_cat = in_obj.find("split_type") != in_obj.cend();
|
||||
if (has_cat) {
|
||||
if (typed) {
|
||||
this->LoadCategoricalSplit<true>(in);
|
||||
} else {
|
||||
this->LoadCategoricalSplit<false>(in);
|
||||
}
|
||||
}
|
||||
// multi-target
|
||||
if (param.size_leaf_vector > 1) {
|
||||
this->p_mt_tree_.reset(new MultiTargetTree{¶m});
|
||||
this->GetMultiTargetTree()->LoadModel(in);
|
||||
return;
|
||||
}
|
||||
|
||||
bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
|
||||
if (typed && feature_is_64) {
|
||||
LoadModelImpl<true, true>(in, param, &stats_, &nodes_);
|
||||
} else if (typed && !feature_is_64) {
|
||||
LoadModelImpl<true, false>(in, param, &stats_, &nodes_);
|
||||
} else if (!typed && feature_is_64) {
|
||||
LoadModelImpl<false, true>(in, param, &stats_, &nodes_);
|
||||
} else {
|
||||
LoadModelImpl<false, false>(in, param, &stats_, &nodes_);
|
||||
}
|
||||
|
||||
if (!has_cat) {
|
||||
this->split_categories_segments_.resize(this->param.num_nodes);
|
||||
this->split_types_.resize(this->param.num_nodes);
|
||||
std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical);
|
||||
}
|
||||
|
||||
@@ -1144,16 +1133,26 @@ void RegTree::LoadModel(Json const& in) {
|
||||
}
|
||||
|
||||
void RegTree::SaveModel(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
// basic properties
|
||||
out["tree_param"] = ToJson(param);
|
||||
// categorical splits
|
||||
this->SaveCategoricalSplit(p_out);
|
||||
// multi-target
|
||||
if (this->IsMultiTarget()) {
|
||||
CHECK_GT(param.size_leaf_vector, 1);
|
||||
this->GetMultiTargetTree()->SaveModel(p_out);
|
||||
return;
|
||||
}
|
||||
/* Here we are treating leaf node and internal node equally. Some information like
|
||||
* child node id doesn't make sense for leaf node but we will have to save them to
|
||||
* avoid creating a huge map. One difficulty is XGBoost has deleted node created by
|
||||
* pruner, and this pruner can be used inside another updater so leaf are not necessary
|
||||
* at the end of node array.
|
||||
*/
|
||||
auto& out = *p_out;
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
|
||||
out["tree_param"] = ToJson(param);
|
||||
|
||||
CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes));
|
||||
auto n_nodes = param.num_nodes;
|
||||
|
||||
@@ -1167,12 +1166,12 @@ void RegTree::SaveModel(Json* p_out) const {
|
||||
I32Array rights(n_nodes);
|
||||
I32Array parents(n_nodes);
|
||||
|
||||
|
||||
F32Array conds(n_nodes);
|
||||
U8Array default_left(n_nodes);
|
||||
U8Array split_type(n_nodes);
|
||||
CHECK_EQ(this->split_types_.size(), param.num_nodes);
|
||||
|
||||
namespace tf = tree_field;
|
||||
|
||||
auto save_tree = [&](auto* p_indices_array) {
|
||||
auto& indices_array = *p_indices_array;
|
||||
for (bst_node_t i = 0; i < n_nodes; ++i) {
|
||||
@@ -1188,33 +1187,28 @@ void RegTree::SaveModel(Json* p_out) const {
|
||||
indices_array.Set(i, n.SplitIndex());
|
||||
conds.Set(i, n.SplitCond());
|
||||
default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft()));
|
||||
|
||||
split_type.Set(i, static_cast<uint8_t>(this->NodeSplitType(i)));
|
||||
}
|
||||
};
|
||||
if (this->param.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
|
||||
I64Array indices_64(n_nodes);
|
||||
save_tree(&indices_64);
|
||||
out["split_indices"] = std::move(indices_64);
|
||||
out[tf::kSplitIdx] = std::move(indices_64);
|
||||
} else {
|
||||
I32Array indices_32(n_nodes);
|
||||
save_tree(&indices_32);
|
||||
out["split_indices"] = std::move(indices_32);
|
||||
out[tf::kSplitIdx] = std::move(indices_32);
|
||||
}
|
||||
|
||||
this->SaveCategoricalSplit(&out);
|
||||
out[tf::kLossChg] = std::move(loss_changes);
|
||||
out[tf::kSumHess] = std::move(sum_hessian);
|
||||
out[tf::kBaseWeight] = std::move(base_weights);
|
||||
|
||||
out["split_type"] = std::move(split_type);
|
||||
out["loss_changes"] = std::move(loss_changes);
|
||||
out["sum_hessian"] = std::move(sum_hessian);
|
||||
out["base_weights"] = std::move(base_weights);
|
||||
out[tf::kLeft] = std::move(lefts);
|
||||
out[tf::kRight] = std::move(rights);
|
||||
out[tf::kParent] = std::move(parents);
|
||||
|
||||
out["left_children"] = std::move(lefts);
|
||||
out["right_children"] = std::move(rights);
|
||||
out["parents"] = std::move(parents);
|
||||
|
||||
out["split_conditions"] = std::move(conds);
|
||||
out["default_left"] = std::move(default_left);
|
||||
out[tf::kSplitCond] = std::move(conds);
|
||||
out[tf::kDftLeft] = std::move(default_left);
|
||||
}
|
||||
|
||||
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||
|
||||
@@ -445,7 +445,7 @@ struct GPUHistMakerDevice {
|
||||
|
||||
dh::caching_device_vector<FeatureType> d_split_types;
|
||||
dh::caching_device_vector<uint32_t> d_categories;
|
||||
dh::caching_device_vector<RegTree::Segment> d_categories_segments;
|
||||
dh::caching_device_vector<RegTree::CategoricalSplitMatrix::Segment> d_categories_segments;
|
||||
|
||||
if (!categories.empty()) {
|
||||
dh::CopyToD(h_split_types, &d_split_types);
|
||||
@@ -458,12 +458,11 @@ struct GPUHistMakerDevice {
|
||||
p_out_position);
|
||||
}
|
||||
|
||||
void FinalisePositionInPage(EllpackPageImpl const *page,
|
||||
const common::Span<RegTree::Node> d_nodes,
|
||||
common::Span<FeatureType const> d_feature_types,
|
||||
common::Span<uint32_t const> categories,
|
||||
common::Span<RegTree::Segment> categories_segments,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
void FinalisePositionInPage(
|
||||
EllpackPageImpl const* page, const common::Span<RegTree::Node> d_nodes,
|
||||
common::Span<FeatureType const> d_feature_types, common::Span<uint32_t const> categories,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment> categories_segments,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
|
||||
auto d_gpair = this->gpair;
|
||||
p_out_position->SetDevice(ctx_->gpu_id);
|
||||
|
||||
Reference in New Issue
Block a user