/** * Copyright 2023 by XGBoost contributors * * \brief Core data structure for multi-target trees. */ #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_ #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_ #include // for bst_node_t, bst_target_t, bst_feature_t #include // for Context #include // for VectorView #include // for Model #include // for Span #include // for uint8_t #include // for size_t #include // for vector namespace xgboost { struct TreeParam; /** * \brief Tree structure for multi-target model. */ class MultiTargetTree : public Model { public: static bst_node_t constexpr InvalidNodeId() { return -1; } private: TreeParam const* param_; std::vector left_; std::vector right_; std::vector parent_; std::vector split_index_; std::vector default_left_; std::vector split_conds_; std::vector weights_; [[nodiscard]] linalg::VectorView NodeWeight(bst_node_t nidx) const { auto beg = nidx * this->NumTarget(); auto v = common::Span{weights_}.subspan(beg, this->NumTarget()); return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size()); } [[nodiscard]] linalg::VectorView NodeWeight(bst_node_t nidx) { auto beg = nidx * this->NumTarget(); auto v = common::Span{weights_}.subspan(beg, this->NumTarget()); return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size()); } public: explicit MultiTargetTree(TreeParam const* param); /** * \brief Set the weight for a leaf. */ void SetLeaf(bst_node_t nidx, linalg::VectorView weight); /** * \brief Expand a leaf into split node. */ void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView base_weight, linalg::VectorView left_weight, linalg::VectorView right_weight); [[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); } [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); } [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); } [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); } [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; } [[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; } [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; } [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const { return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx); } [[nodiscard]] bst_target_t NumTarget() const; [[nodiscard]] std::size_t Size() const; [[nodiscard]] bst_node_t Depth(bst_node_t nidx) const { bst_node_t depth{0}; while (Parent(nidx) != InvalidNodeId()) { ++depth; nidx = Parent(nidx); } return depth; } [[nodiscard]] linalg::VectorView LeafValue(bst_node_t nidx) const { CHECK(IsLeaf(nidx)); return this->NodeWeight(nidx); } void LoadModel(Json const& in) override; void SaveModel(Json* out) const override; }; } // namespace xgboost #endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_