/**
 * Copyright 2023 by XGBoost Contributors
 */
#include "xgboost/multi_target_tree_model.h"

#include <algorithm>    // for copy_n
#include <cstddef>      // for size_t
#include <cstdint>      // for int32_t, uint8_t
#include <limits>       // for numeric_limits
#include <string_view>  // for string_view
#include <utility>      // for move
#include <vector>       // for vector

#include "io_utils.h"            // for I32ArrayT, FloatArrayT, GetElem, ...
#include "xgboost/base.h"        // for bst_node_t, bst_feature_t, bst_target_t
#include "xgboost/json.h"        // for Json, get, Object, Number, Integer, ...
#include "xgboost/logging.h"
#include "xgboost/tree_model.h"  // for TreeParam

namespace xgboost {
/**
 * @brief Construct a tree that contains only the root node.
 *
 * The root has no children and no parent (InvalidNodeId()), split feature 0, default_left 0,
 * a NaN split condition, and `param->size_leaf_vector` NaN weights.
 *
 * @param param Tree parameters; the pointer is stored and read by later calls.
 */
MultiTargetTree::MultiTargetTree(TreeParam const* param)
    : param_{param},
      left_(1ul, InvalidNodeId()),
      right_(1ul, InvalidNodeId()),
      parent_(1ul, InvalidNodeId()),
      split_index_(1ul, 0),
      default_left_(1ul, 0),
      split_conds_(1ul, std::numeric_limits<float>::quiet_NaN()),
      weights_(param->size_leaf_vector, std::numeric_limits<float>::quiet_NaN()) {
  // A multi-target tree requires a leaf vector with more than one element.
  CHECK_GT(param_->size_leaf_vector, 1);
}

namespace {
/**
 * @brief Read the per-node arrays of a multi-target tree from its JSON object.
 *
 * @tparam typed          Selects the typed-array variant of the JSON containers
 *                        (FloatArrayT/I32ArrayT/U8ArrayT from io_utils.h); decided by
 *                        MultiTargetTree::LoadModel from the type of the base-weight field.
 * @tparam feature_is_64  Whether split indices are stored as a 64-bit integer array.
 *
 * Each output vector is resized to the length of its JSON array and filled element-wise.
 * A missing field is reported through CHECK instead of dereferencing an end iterator.
 */
template <bool typed, bool feature_is_64>
void LoadModelImpl(Json const& in, std::vector<float>* p_weights, std::vector<bst_node_t>* p_lefts,
                   std::vector<bst_node_t>* p_rights, std::vector<bst_node_t>* p_parents,
                   std::vector<float>* p_conds, std::vector<bst_feature_t>* p_fidx,
                   std::vector<std::uint8_t>* p_dft_left) {
  namespace tf = tree_field;

  auto const& obj = get<Object const>(in);
  // Non-inserting lookup: `find(name)->second` on a missing key would be undefined behaviour,
  // and `operator[]` may insert into the (logically const) input object.
  auto field = [&](std::string_view name) -> Json const& {
    auto it = obj.find(name);
    CHECK(it != obj.cend()) << "Missing field `" << name << "` in the multi-target tree model.";
    return it->second;
  };

  auto get_float = [&](std::string_view name, std::vector<float>* p_out) {
    auto const& values = get<FloatArrayT<typed>>(field(name));
    auto& out = *p_out;
    out.resize(values.size());
    for (std::size_t i = 0; i < values.size(); ++i) {
      out[i] = GetElem<Number>(values, i);
    }
  };
  get_float(tf::kBaseWeight, p_weights);
  get_float(tf::kSplitCond, p_conds);

  auto get_nidx = [&](std::string_view name, std::vector<bst_node_t>* p_nidx) {
    auto const& nidx = get<I32ArrayT<typed>>(field(name));
    auto& out_nidx = *p_nidx;
    out_nidx.resize(nidx.size());
    for (std::size_t i = 0; i < nidx.size(); ++i) {
      out_nidx[i] = GetElem<Integer>(nidx, i);
    }
  };
  get_nidx(tf::kLeft, p_lefts);
  get_nidx(tf::kRight, p_rights);
  get_nidx(tf::kParent, p_parents);

  auto const& splits = get<IndexArrayT<typed, feature_is_64> const>(field(tf::kSplitIdx));
  auto& out_fidx = *p_fidx;
  out_fidx.resize(splits.size());
  for (std::size_t i = 0; i < splits.size(); ++i) {
    out_fidx[i] = GetElem<Integer>(splits, i);
  }

  auto const& dft_left = get<U8ArrayT<typed> const>(field(tf::kDftLeft));
  auto& out_dft_l = *p_dft_left;
  out_dft_l.resize(dft_left.size());
  for (std::size_t i = 0; i < dft_left.size(); ++i) {
    out_dft_l[i] = GetElem<Boolean>(dft_left, i);
  }
}
}  // namespace

/**
 * @brief Load the tree from JSON, dispatching on how the arrays were serialized.
 *
 * `typed` is true when the base weights are stored as an F32Array; `feature_is_64` is true
 * when split indices are stored as an I64Array.
 */
void MultiTargetTree::LoadModel(Json const& in) {
  namespace tf = tree_field;
  bool typed = IsA<F32Array>(in[tf::kBaseWeight]);
  bool feature_is_64 = IsA<I64Array>(in[tf::kSplitIdx]);

  if (typed && feature_is_64) {
    LoadModelImpl<true, true>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
                              &split_index_, &default_left_);
  } else if (typed && !feature_is_64) {
    LoadModelImpl<true, false>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
                               &split_index_, &default_left_);
  } else if (!typed && feature_is_64) {
    LoadModelImpl<false, true>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
                               &split_index_, &default_left_);
  } else {
    LoadModelImpl<false, false>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
                                &split_index_, &default_left_);
  }
}

/**
 * @brief Serialize the first `param_->num_nodes` nodes into `*p_out` as typed JSON arrays.
 *
 * Weights are written as one flat array where node `nidx` owns the slice
 * [nidx * NumTarget(), (nidx + 1) * NumTarget()). Split indices use an I64Array only when
 * `num_feature` exceeds the int32 range, otherwise an I32Array.
 */
void MultiTargetTree::SaveModel(Json* p_out) const {
  CHECK(p_out);
  auto& out = *p_out;

  auto n_nodes = param_->num_nodes;

  // nodes
  I32Array lefts(n_nodes);
  I32Array rights(n_nodes);
  I32Array parents(n_nodes);
  F32Array conds(n_nodes);
  U8Array default_left(n_nodes);
  F32Array weights(n_nodes * this->NumTarget());

  // Generic over the split-index array type (I32Array or I64Array).
  auto save_tree = [&](auto* p_indices_array) {
    auto& indices_array = *p_indices_array;
    for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) {
      CHECK_LT(nidx, left_.size());
      lefts.Set(nidx, left_[nidx]);
      CHECK_LT(nidx, right_.size());
      rights.Set(nidx, right_[nidx]);
      CHECK_LT(nidx, parent_.size());
      parents.Set(nidx, parent_[nidx]);
      CHECK_LT(nidx, split_index_.size());
      indices_array.Set(nidx, split_index_[nidx]);
      CHECK_LT(nidx, split_conds_.size());
      conds.Set(nidx, split_conds_[nidx]);
      CHECK_LT(nidx, default_left_.size());
      default_left.Set(nidx, default_left_[nidx]);

      auto in_weight = this->NodeWeight(nidx);
      auto weight_out = common::Span<float>(weights.GetArray())
                            .subspan(nidx * this->NumTarget(), this->NumTarget());
      CHECK_EQ(in_weight.Size(), weight_out.size());
      std::copy_n(in_weight.Values().data(), in_weight.Size(), weight_out.data());
    }
  };

  namespace tf = tree_field;

  if (this->param_->num_feature >
      static_cast<bst_feature_t>(std::numeric_limits<std::int32_t>::max())) {
    I64Array indices_64(n_nodes);
    save_tree(&indices_64);
    out[tf::kSplitIdx] = std::move(indices_64);
  } else {
    I32Array indices_32(n_nodes);
    save_tree(&indices_32);
    out[tf::kSplitIdx] = std::move(indices_32);
  }

  out[tf::kBaseWeight] = std::move(weights);
  out[tf::kLeft] = std::move(lefts);
  out[tf::kRight] = std::move(rights);
  out[tf::kParent] = std::move(parents);
  out[tf::kSplitCond] = std::move(conds);
  out[tf::kDftLeft] = std::move(default_left);
}

/**
 * @brief Overwrite the weight vector of leaf `nidx` with `weight`.
 *
 * `nidx` must already be a leaf (collapsing a split node is not implemented) and `weight`
 * must have exactly NumTarget() elements.
 */
void MultiTargetTree::SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
  CHECK(this->IsLeaf(nidx)) << "Collapsing a split node to leaf " << MTNotImplemented();
  auto const next_nidx = nidx + 1;
  CHECK_EQ(weight.Size(), this->NumTarget());
  CHECK_GE(weights_.size(), next_nidx * weight.Size());
  auto out_weight = common::Span<float>(weights_).subspan(nidx * weight.Size(), weight.Size());
  for (std::size_t i = 0; i < weight.Size(); ++i) {
    out_weight[i] = weight(i);
  }
}

/**
 * @brief Split leaf `nidx`, appending its two children at indices `num_nodes` and
 *        `num_nodes + 1`.
 *
 * Records the split feature, condition and default direction on `nidx`, and copies the
 * three weight vectors into the parent and the new children. This function does not modify
 * `param_->num_nodes`; presumably the owning tree updates it after the call — TODO confirm.
 */
void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond,
                             bool default_left, linalg::VectorView<float const> base_weight,
                             linalg::VectorView<float const> left_weight,
                             linalg::VectorView<float const> right_weight) {
  CHECK(this->IsLeaf(nidx));
  CHECK_GE(parent_.size(), 1);
  CHECK_EQ(parent_.size(), left_.size());
  CHECK_EQ(left_.size(), right_.size());

  std::size_t n = param_->num_nodes + 2;
  CHECK_LT(split_idx, this->param_->num_feature);
  left_.resize(n, InvalidNodeId());
  right_.resize(n, InvalidNodeId());
  parent_.resize(n, InvalidNodeId());

  // The two new nodes occupy the last two slots after the resize.
  auto left_child = parent_.size() - 2;
  auto right_child = parent_.size() - 1;

  left_[nidx] = left_child;
  right_[nidx] = right_child;

  // Only the root may lack a parent.
  if (nidx != 0) {
    CHECK_NE(parent_[nidx], InvalidNodeId());
  }

  parent_[left_child] = nidx;
  parent_[right_child] = nidx;

  split_index_.resize(n);
  split_index_[nidx] = split_idx;

  split_conds_.resize(n, std::numeric_limits<float>::quiet_NaN());
  split_conds_[nidx] = split_cond;

  default_left_.resize(n);
  default_left_[nidx] = static_cast<std::uint8_t>(default_left);

  weights_.resize(n * this->NumTarget());
  auto p_weight = this->NodeWeight(nidx);
  CHECK_EQ(p_weight.Size(), base_weight.Size());
  auto l_weight = this->NodeWeight(left_child);
  CHECK_EQ(l_weight.Size(), left_weight.Size());
  auto r_weight = this->NodeWeight(right_child);
  CHECK_EQ(r_weight.Size(), right_weight.Size());

  for (std::size_t i = 0; i < base_weight.Size(); ++i) {
    p_weight(i) = base_weight(i);
    l_weight(i) = left_weight(i);
    r_weight(i) = right_weight(i);
  }
}

/** @brief Number of targets, i.e. the length of each node's weight vector. */
bst_target_t MultiTargetTree::NumTarget() const { return param_->size_leaf_vector; }

/** @brief Number of allocated nodes (length of the parent array). */
std::size_t MultiTargetTree::Size() const { return parent_.size(); }
}  // namespace xgboost