Define core multi-target regression tree structure. (#8884)

- Define a new tree struct embedded in the `RegTree`.
- Provide dispatching functions in `RegTree`.
- Fix some c++-17 warnings about the use of nodiscard (currently we disable the warning on
  the CI).
- Use uint32_t instead of size_t for `bst_target_t` as it has a defined size and can be used
  as part of dmlc parameter.
- Hide the `Segment` struct inside the categorical split matrix.
This commit is contained in:
Jiaming Yuan 2023-03-09 19:03:06 +08:00 committed by GitHub
parent 46dfcc7d22
commit 5feee8d4a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 809 additions and 264 deletions

View File

@ -61,6 +61,7 @@ OBJECTS= \
$(PKGROOT)/src/tree/fit_stump.o \ $(PKGROOT)/src/tree/fit_stump.o \
$(PKGROOT)/src/tree/tree_model.o \ $(PKGROOT)/src/tree/tree_model.o \
$(PKGROOT)/src/tree/tree_updater.o \ $(PKGROOT)/src/tree/tree_updater.o \
$(PKGROOT)/src/tree/multi_target_tree_model.o \
$(PKGROOT)/src/tree/updater_approx.o \ $(PKGROOT)/src/tree/updater_approx.o \
$(PKGROOT)/src/tree/updater_colmaker.o \ $(PKGROOT)/src/tree/updater_colmaker.o \
$(PKGROOT)/src/tree/updater_prune.o \ $(PKGROOT)/src/tree/updater_prune.o \

View File

@ -60,6 +60,7 @@ OBJECTS= \
$(PKGROOT)/src/tree/param.o \ $(PKGROOT)/src/tree/param.o \
$(PKGROOT)/src/tree/fit_stump.o \ $(PKGROOT)/src/tree/fit_stump.o \
$(PKGROOT)/src/tree/tree_model.o \ $(PKGROOT)/src/tree/tree_model.o \
$(PKGROOT)/src/tree/multi_target_tree_model.o \
$(PKGROOT)/src/tree/tree_updater.o \ $(PKGROOT)/src/tree/tree_updater.o \
$(PKGROOT)/src/tree/updater_approx.o \ $(PKGROOT)/src/tree/updater_approx.o \
$(PKGROOT)/src/tree/updater_colmaker.o \ $(PKGROOT)/src/tree/updater_colmaker.o \

View File

@ -110,11 +110,11 @@ using bst_bin_t = int32_t; // NOLINT
*/ */
using bst_row_t = std::size_t; // NOLINT using bst_row_t = std::size_t; // NOLINT
/*! \brief Type for tree node index. */ /*! \brief Type for tree node index. */
using bst_node_t = int32_t; // NOLINT using bst_node_t = std::int32_t; // NOLINT
/*! \brief Type for ranking group index. */ /*! \brief Type for ranking group index. */
using bst_group_t = uint32_t; // NOLINT using bst_group_t = std::uint32_t; // NOLINT
/*! \brief Type for indexing target variables. */ /*! \brief Type for indexing into output targets. */
using bst_target_t = std::size_t; // NOLINT using bst_target_t = std::uint32_t; // NOLINT
namespace detail { namespace detail {
/*! \brief Implementation of gradient statistics pair. Template specialisation /*! \brief Implementation of gradient statistics pair. Template specialisation

View File

@ -0,0 +1,96 @@
/**
* Copyright 2023 by XGBoost contributors
*
* \brief Core data structure for multi-target trees.
*/
#ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
#define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
#include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
#include <xgboost/context.h> // for Context
#include <xgboost/linalg.h> // for VectorView
#include <xgboost/model.h> // for Model
#include <xgboost/span.h> // for Span
#include <cinttypes> // for uint8_t
#include <cstddef> // for size_t
#include <vector> // for vector
namespace xgboost {
struct TreeParam;
/**
* \brief Tree structure for multi-target model.
*/
class MultiTargetTree : public Model {
public:
static bst_node_t constexpr InvalidNodeId() { return -1; }
private:
TreeParam const* param_;
std::vector<bst_node_t> left_;
std::vector<bst_node_t> right_;
std::vector<bst_node_t> parent_;
std::vector<bst_feature_t> split_index_;
std::vector<std::uint8_t> default_left_;
std::vector<float> split_conds_;
std::vector<float> weights_;
[[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
auto beg = nidx * this->NumTarget();
auto v = common::Span<float const>{weights_}.subspan(beg, this->NumTarget());
return linalg::MakeTensorView(Context::kCpuId, v, v.size());
}
[[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
auto beg = nidx * this->NumTarget();
auto v = common::Span<float>{weights_}.subspan(beg, this->NumTarget());
return linalg::MakeTensorView(Context::kCpuId, v, v.size());
}
public:
explicit MultiTargetTree(TreeParam const* param);
/**
* \brief Set the weight for a leaf.
*/
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight);
/**
* \brief Expand a leaf into split node.
*/
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); }
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); }
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); }
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); }
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; }
[[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; }
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; }
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
}
[[nodiscard]] bst_target_t NumTarget() const;
[[nodiscard]] std::size_t Size() const;
[[nodiscard]] bst_node_t Depth(bst_node_t nidx) const {
bst_node_t depth{0};
while (Parent(nidx) != InvalidNodeId()) {
++depth;
nidx = Parent(nidx);
}
return depth;
}
[[nodiscard]] linalg::VectorView<float const> LeafValue(bst_node_t nidx) const {
CHECK(IsLeaf(nidx));
return this->NodeWeight(nidx);
}
void LoadModel(Json const& in) override;
void SaveModel(Json* out) const override;
};
} // namespace xgboost
#endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2014-2022 by Contributors * Copyright 2014-2023 by Contributors
* \file tree_model.h * \file tree_model.h
* \brief model structure for tree * \brief model structure for tree
* \author Tianqi Chen * \author Tianqi Chen
@ -9,60 +9,57 @@
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/parameter.h> #include <dmlc/parameter.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include <xgboost/logging.h>
#include <xgboost/feature_map.h> #include <xgboost/feature_map.h>
#include <xgboost/linalg.h> // for VectorView
#include <xgboost/logging.h>
#include <xgboost/model.h> #include <xgboost/model.h>
#include <xgboost/multi_target_tree_model.h> // for MultiTargetTree
#include <limits>
#include <vector>
#include <string>
#include <cstring>
#include <algorithm> #include <algorithm>
#include <tuple> #include <cstring>
#include <limits>
#include <memory> // for make_unique
#include <stack> #include <stack>
#include <string>
#include <tuple>
#include <vector>
namespace xgboost { namespace xgboost {
struct PathElement; // forward declaration
class Json; class Json;
// FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should // FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
// not be configured by users. // not be configured by users.
/*! \brief meta parameters of the tree */ /*! \brief meta parameters of the tree */
struct TreeParam : public dmlc::Parameter<TreeParam> { struct TreeParam : public dmlc::Parameter<TreeParam> {
/*! \brief (Deprecated) number of start root */ /*! \brief (Deprecated) number of start root */
int deprecated_num_roots; int deprecated_num_roots{1};
/*! \brief total number of nodes */ /*! \brief total number of nodes */
int num_nodes; int num_nodes{1};
/*!\brief number of deleted nodes */ /*!\brief number of deleted nodes */
int num_deleted; int num_deleted{0};
/*! \brief maximum depth, this is a statistics of the tree */ /*! \brief maximum depth, this is a statistics of the tree */
int deprecated_max_depth; int deprecated_max_depth{0};
/*! \brief number of features used for tree construction */ /*! \brief number of features used for tree construction */
bst_feature_t num_feature; bst_feature_t num_feature{0};
/*! /*!
* \brief leaf vector size, used for vector tree * \brief leaf vector size, used for vector tree
* used to store more than one dimensional information in tree * used to store more than one dimensional information in tree
*/ */
int size_leaf_vector; bst_target_t size_leaf_vector{1};
/*! \brief reserved part, make sure alignment works for 64bit */ /*! \brief reserved part, make sure alignment works for 64bit */
int reserved[31]; int reserved[31];
/*! \brief constructor */ /*! \brief constructor */
TreeParam() { TreeParam() {
// assert compact alignment // assert compact alignment
static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), "TreeParam: 64 bit align");
"TreeParam: 64 bit align"); std::memset(reserved, 0, sizeof(reserved));
std::memset(this, 0, sizeof(TreeParam));
num_nodes = 1;
deprecated_num_roots = 1;
} }
// Swap byte order for all fields. Useful for transporting models between machines with different // Swap byte order for all fields. Useful for transporting models between machines with different
// endianness (big endian vs little endian) // endianness (big endian vs little endian)
inline TreeParam ByteSwap() const { [[nodiscard]] TreeParam ByteSwap() const {
TreeParam x = *this; TreeParam x = *this;
dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1); dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1); dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
@ -80,17 +77,18 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
// other arguments are set by the algorithm. // other arguments are set by the algorithm.
DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1); DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
DMLC_DECLARE_FIELD(num_feature) DMLC_DECLARE_FIELD(num_feature)
.set_default(0)
.describe("Number of features used in tree construction."); .describe("Number of features used in tree construction.");
DMLC_DECLARE_FIELD(num_deleted); DMLC_DECLARE_FIELD(num_deleted).set_default(0);
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0) DMLC_DECLARE_FIELD(size_leaf_vector)
.set_lower_bound(0)
.set_default(1)
.describe("Size of leaf vector, reserved for vector tree"); .describe("Size of leaf vector, reserved for vector tree");
} }
bool operator==(const TreeParam& b) const { bool operator==(const TreeParam& b) const {
return num_nodes == b.num_nodes && return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
num_deleted == b.num_deleted && num_feature == b.num_feature && size_leaf_vector == b.size_leaf_vector;
num_feature == b.num_feature &&
size_leaf_vector == b.size_leaf_vector;
} }
}; };
@ -114,7 +112,7 @@ struct RTreeNodeStat {
} }
// Swap byte order for all fields. Useful for transporting models between machines with different // Swap byte order for all fields. Useful for transporting models between machines with different
// endianness (big endian vs little endian) // endianness (big endian vs little endian)
inline RTreeNodeStat ByteSwap() const { [[nodiscard]] RTreeNodeStat ByteSwap() const {
RTreeNodeStat x = *this; RTreeNodeStat x = *this;
dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1); dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1); dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
@ -124,14 +122,43 @@ struct RTreeNodeStat {
} }
}; };
/*! /**
* \brief Helper for defining copyable data structure that contains unique pointers.
*/
template <typename T>
class CopyUniquePtr {
std::unique_ptr<T> ptr_{nullptr};
public:
CopyUniquePtr() = default;
CopyUniquePtr(CopyUniquePtr const& that) {
ptr_.reset(nullptr);
if (that.ptr_) {
ptr_ = std::make_unique<T>(*that);
}
}
T* get() const noexcept { return ptr_.get(); } // NOLINT
T& operator*() { return *ptr_; }
T* operator->() noexcept { return this->get(); }
T const& operator*() const { return *ptr_; }
T const* operator->() const noexcept { return this->get(); }
explicit operator bool() const { return static_cast<bool>(ptr_); }
bool operator!() const { return !ptr_; }
void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
};
/**
* \brief define regression tree to be the most common tree model. * \brief define regression tree to be the most common tree model.
*
* This is the data structure used in xgboost's major tree models. * This is the data structure used in xgboost's major tree models.
*/ */
class RegTree : public Model { class RegTree : public Model {
public: public:
using SplitCondT = bst_float; using SplitCondT = bst_float;
static constexpr bst_node_t kInvalidNodeId {-1}; static constexpr bst_node_t kInvalidNodeId{MultiTargetTree::InvalidNodeId()};
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max(); static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
static constexpr bst_node_t kRoot{0}; static constexpr bst_node_t kRoot{0};
@ -151,51 +178,51 @@ class RegTree : public Model {
} }
/*! \brief index of left child */ /*! \brief index of left child */
XGBOOST_DEVICE int LeftChild() const { XGBOOST_DEVICE [[nodiscard]] int LeftChild() const {
return this->cleft_; return this->cleft_;
} }
/*! \brief index of right child */ /*! \brief index of right child */
XGBOOST_DEVICE int RightChild() const { XGBOOST_DEVICE [[nodiscard]] int RightChild() const {
return this->cright_; return this->cright_;
} }
/*! \brief index of default child when feature is missing */ /*! \brief index of default child when feature is missing */
XGBOOST_DEVICE int DefaultChild() const { XGBOOST_DEVICE [[nodiscard]] int DefaultChild() const {
return this->DefaultLeft() ? this->LeftChild() : this->RightChild(); return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
} }
/*! \brief feature index of split condition */ /*! \brief feature index of split condition */
XGBOOST_DEVICE unsigned SplitIndex() const { XGBOOST_DEVICE [[nodiscard]] unsigned SplitIndex() const {
return sindex_ & ((1U << 31) - 1U); return sindex_ & ((1U << 31) - 1U);
} }
/*! \brief when feature is unknown, whether goes to left child */ /*! \brief when feature is unknown, whether goes to left child */
XGBOOST_DEVICE bool DefaultLeft() const { XGBOOST_DEVICE [[nodiscard]] bool DefaultLeft() const {
return (sindex_ >> 31) != 0; return (sindex_ >> 31) != 0;
} }
/*! \brief whether current node is leaf node */ /*! \brief whether current node is leaf node */
XGBOOST_DEVICE bool IsLeaf() const { XGBOOST_DEVICE [[nodiscard]] bool IsLeaf() const {
return cleft_ == kInvalidNodeId; return cleft_ == kInvalidNodeId;
} }
/*! \return get leaf value of leaf node */ /*! \return get leaf value of leaf node */
XGBOOST_DEVICE bst_float LeafValue() const { XGBOOST_DEVICE [[nodiscard]] float LeafValue() const {
return (this->info_).leaf_value; return (this->info_).leaf_value;
} }
/*! \return get split condition of the node */ /*! \return get split condition of the node */
XGBOOST_DEVICE SplitCondT SplitCond() const { XGBOOST_DEVICE [[nodiscard]] SplitCondT SplitCond() const {
return (this->info_).split_cond; return (this->info_).split_cond;
} }
/*! \brief get parent of the node */ /*! \brief get parent of the node */
XGBOOST_DEVICE int Parent() const { XGBOOST_DEVICE [[nodiscard]] int Parent() const {
return parent_ & ((1U << 31) - 1); return parent_ & ((1U << 31) - 1);
} }
/*! \brief whether current node is left child */ /*! \brief whether current node is left child */
XGBOOST_DEVICE bool IsLeftChild() const { XGBOOST_DEVICE [[nodiscard]] bool IsLeftChild() const {
return (parent_ & (1U << 31)) != 0; return (parent_ & (1U << 31)) != 0;
} }
/*! \brief whether this node is deleted */ /*! \brief whether this node is deleted */
XGBOOST_DEVICE bool IsDeleted() const { XGBOOST_DEVICE [[nodiscard]] bool IsDeleted() const {
return sindex_ == kDeletedNodeMarker; return sindex_ == kDeletedNodeMarker;
} }
/*! \brief whether current node is root */ /*! \brief whether current node is root */
XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; } XGBOOST_DEVICE [[nodiscard]] bool IsRoot() const { return parent_ == kInvalidNodeId; }
/*! /*!
* \brief set the left child * \brief set the left child
* \param nid node id to right child * \param nid node id to right child
@ -252,7 +279,7 @@ class RegTree : public Model {
info_.leaf_value == b.info_.leaf_value; info_.leaf_value == b.info_.leaf_value;
} }
inline Node ByteSwap() const { [[nodiscard]] Node ByteSwap() const {
Node x = *this; Node x = *this;
dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1); dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1); dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
@ -312,10 +339,8 @@ class RegTree : public Model {
/*! \brief model parameter */ /*! \brief model parameter */
TreeParam param; TreeParam param;
/*! \brief constructor */
RegTree() { RegTree() {
param.num_nodes = 1; param.Init(Args{});
param.num_deleted = 0;
nodes_.resize(param.num_nodes); nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes); stats_.resize(param.num_nodes);
split_types_.resize(param.num_nodes, FeatureType::kNumerical); split_types_.resize(param.num_nodes, FeatureType::kNumerical);
@ -325,6 +350,17 @@ class RegTree : public Model {
nodes_[i].SetParent(kInvalidNodeId); nodes_[i].SetParent(kInvalidNodeId);
} }
} }
/**
* \brief Constructor that initializes the tree model with shape.
*/
explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
param.num_feature = n_features;
param.size_leaf_vector = n_targets;
if (n_targets > 1) {
this->p_mt_tree_.reset(new MultiTargetTree{&param});
}
}
/*! \brief get node given nid */ /*! \brief get node given nid */
Node& operator[](int nid) { Node& operator[](int nid) {
return nodes_[nid]; return nodes_[nid];
@ -335,17 +371,17 @@ class RegTree : public Model {
} }
/*! \brief get const reference to nodes */ /*! \brief get const reference to nodes */
const std::vector<Node>& GetNodes() const { return nodes_; } [[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
/*! \brief get const reference to stats */ /*! \brief get const reference to stats */
const std::vector<RTreeNodeStat>& GetStats() const { return stats_; } [[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
/*! \brief get node statistics given nid */ /*! \brief get node statistics given nid */
RTreeNodeStat& Stat(int nid) { RTreeNodeStat& Stat(int nid) {
return stats_[nid]; return stats_[nid];
} }
/*! \brief get node statistics given nid */ /*! \brief get node statistics given nid */
const RTreeNodeStat& Stat(int nid) const { [[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
return stats_[nid]; return stats_[nid];
} }
@ -398,7 +434,7 @@ class RegTree : public Model {
* *
* \param b The other tree. * \param b The other tree.
*/ */
bool Equal(const RegTree& b) const; [[nodiscard]] bool Equal(const RegTree& b) const;
/** /**
* \brief Expands a leaf node into two additional leaf nodes. * \brief Expands a leaf node into two additional leaf nodes.
@ -424,6 +460,11 @@ class RegTree : public Model {
float right_sum, float right_sum,
bst_node_t leaf_right_child = kInvalidNodeId); bst_node_t leaf_right_child = kInvalidNodeId);
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);
/** /**
* \brief Expands a leaf node with categories * \brief Expands a leaf node with categories
* *
@ -445,15 +486,27 @@ class RegTree : public Model {
bst_float right_leaf_weight, bst_float loss_change, float sum_hess, bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
float left_sum, float right_sum); float left_sum, float right_sum);
bool HasCategoricalSplit() const { [[nodiscard]] bool HasCategoricalSplit() const {
return !split_categories_.empty(); return !split_categories_.empty();
} }
/**
* \brief Whether this is a multi-target tree.
*/
[[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
[[nodiscard]] bst_target_t NumTargets() const { return param.size_leaf_vector; }
[[nodiscard]] auto GetMultiTargetTree() const {
CHECK(IsMultiTarget());
return p_mt_tree_.get();
}
/*! /*!
* \brief get current depth * \brief get current depth
* \param nid node id * \param nid node id
*/ */
int GetDepth(int nid) const { [[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
if (IsMultiTarget()) {
return this->p_mt_tree_->Depth(nid);
}
int depth = 0; int depth = 0;
while (!nodes_[nid].IsRoot()) { while (!nodes_[nid].IsRoot()) {
++depth; ++depth;
@ -461,12 +514,16 @@ class RegTree : public Model {
} }
return depth; return depth;
} }
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
CHECK(IsMultiTarget());
return this->p_mt_tree_->SetLeaf(nidx, weight);
}
/*! /*!
* \brief get maximum depth * \brief get maximum depth
* \param nid node id * \param nid node id
*/ */
int MaxDepth(int nid) const { [[nodiscard]] int MaxDepth(int nid) const {
if (nodes_[nid].IsLeaf()) return 0; if (nodes_[nid].IsLeaf()) return 0;
return std::max(MaxDepth(nodes_[nid].LeftChild())+1, return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
MaxDepth(nodes_[nid].RightChild())+1); MaxDepth(nodes_[nid].RightChild())+1);
@ -480,13 +537,13 @@ class RegTree : public Model {
} }
/*! \brief number of extra nodes besides the root */ /*! \brief number of extra nodes besides the root */
int NumExtraNodes() const { [[nodiscard]] int NumExtraNodes() const {
return param.num_nodes - 1 - param.num_deleted; return param.num_nodes - 1 - param.num_deleted;
} }
/* \brief Count number of leaves in tree. */ /* \brief Count number of leaves in tree. */
bst_node_t GetNumLeaves() const; [[nodiscard]] bst_node_t GetNumLeaves() const;
bst_node_t GetNumSplitNodes() const; [[nodiscard]] bst_node_t GetNumSplitNodes() const;
/*! /*!
* \brief dense feature vector that can be taken by RegTree * \brief dense feature vector that can be taken by RegTree
@ -513,20 +570,20 @@ class RegTree : public Model {
* \brief returns the size of the feature vector * \brief returns the size of the feature vector
* \return the size of the feature vector * \return the size of the feature vector
*/ */
size_t Size() const; [[nodiscard]] size_t Size() const;
/*! /*!
* \brief get ith value * \brief get ith value
* \param i feature index. * \param i feature index.
* \return the i-th feature value * \return the i-th feature value
*/ */
bst_float GetFvalue(size_t i) const; [[nodiscard]] bst_float GetFvalue(size_t i) const;
/*! /*!
* \brief check whether i-th entry is missing * \brief check whether i-th entry is missing
* \param i feature index. * \param i feature index.
* \return whether i-th value is missing. * \return whether i-th value is missing.
*/ */
bool IsMissing(size_t i) const; [[nodiscard]] bool IsMissing(size_t i) const;
bool HasMissing() const; [[nodiscard]] bool HasMissing() const;
private: private:
@ -557,56 +614,123 @@ class RegTree : public Model {
* \param format the format to dump the model in * \param format the format to dump the model in
* \return the string of dumped model * \return the string of dumped model
*/ */
std::string DumpModel(const FeatureMap& fmap, [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
bool with_stats,
std::string format) const; std::string format) const;
/*! /*!
* \brief Get split type for a node. * \brief Get split type for a node.
* \param nidx Index of node. * \param nidx Index of node.
* \return The type of this split. For leaf node it's always kNumerical. * \return The type of this split. For leaf node it's always kNumerical.
*/ */
FeatureType NodeSplitType(bst_node_t nidx) const { [[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
return split_types_.at(nidx);
}
/*! /*!
* \brief Get split types for all nodes. * \brief Get split types for all nodes.
*/ */
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; } [[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; } return split_types_;
}
[[nodiscard]] common::Span<uint32_t const> GetSplitCategories() const {
return split_categories_;
}
/*! /*!
* \brief Get the bit storage for categories * \brief Get the bit storage for categories
*/ */
common::Span<uint32_t const> NodeCats(bst_node_t nidx) const { [[nodiscard]] common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
auto node_ptr = GetCategoriesMatrix().node_ptr; auto node_ptr = GetCategoriesMatrix().node_ptr;
auto categories = GetCategoriesMatrix().categories; auto categories = GetCategoriesMatrix().categories;
auto segment = node_ptr[nidx]; auto segment = node_ptr[nidx];
auto node_cats = categories.subspan(segment.beg, segment.size); auto node_cats = categories.subspan(segment.beg, segment.size);
return node_cats; return node_cats;
} }
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; } [[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
// The fields of split_categories_segments_[i] are set such that
// the range split_categories_[beg:(beg+size)] stores the bitset for
// the matching categories for the i-th node.
struct Segment {
size_t beg {0};
size_t size {0};
};
/**
* \brief CSR-like matrix for categorical splits.
*
* The fields of split_categories_segments_[i] are set such that the range
* node_ptr[beg:(beg+size)] stores the bitset for the matching categories for the
* i-th node.
*/
struct CategoricalSplitMatrix { struct CategoricalSplitMatrix {
struct Segment {
std::size_t beg{0};
std::size_t size{0};
};
common::Span<FeatureType const> split_type; common::Span<FeatureType const> split_type;
common::Span<uint32_t const> categories; common::Span<uint32_t const> categories;
common::Span<Segment const> node_ptr; common::Span<Segment const> node_ptr;
}; };
CategoricalSplitMatrix GetCategoriesMatrix() const { [[nodiscard]] CategoricalSplitMatrix GetCategoriesMatrix() const {
CategoricalSplitMatrix view; CategoricalSplitMatrix view;
view.split_type = common::Span<FeatureType const>(this->GetSplitTypes()); view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
view.categories = this->GetSplitCategories(); view.categories = this->GetSplitCategories();
view.node_ptr = common::Span<Segment const>(split_categories_segments_); view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
return view; return view;
} }
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
if (IsMultiTarget()) {
return this->p_mt_tree_->SplitIndex(nidx);
}
return (*this)[nidx].SplitIndex();
}
[[nodiscard]] float SplitCond(bst_node_t nidx) const {
if (IsMultiTarget()) {
return this->p_mt_tree_->SplitCond(nidx);
}
return (*this)[nidx].SplitCond();
}
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
if (IsMultiTarget()) {
return this->p_mt_tree_->DefaultLeft(nidx);
}
return (*this)[nidx].DefaultLeft();
}
[[nodiscard]] bool IsRoot(bst_node_t nidx) const {
if (IsMultiTarget()) {
return nidx == kRoot;
}
return (*this)[nidx].IsRoot();
}
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
if (IsMultiTarget()) {
return this->p_mt_tree_->IsLeaf(nidx);
}
return (*this)[nidx].IsLeaf();
}
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
if (IsMultiTarget()) {
return this->p_mt_tree_->Parent(nidx);
}
return (*this)[nidx].Parent();
}
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
if (IsMultiTarget()) {
return this->p_mt_tree_->LeftChild(nidx);
}
return (*this)[nidx].LeftChild();
}
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
if (IsMultiTarget()) {
return this->p_mt_tree_->RightChild(nidx);
}
return (*this)[nidx].RightChild();
}
[[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
if (IsMultiTarget()) {
CHECK_NE(nidx, kRoot);
auto p = this->p_mt_tree_->Parent(nidx);
return nidx == this->p_mt_tree_->LeftChild(p);
}
return (*this)[nidx].IsLeftChild();
}
[[nodiscard]] bst_node_t Size() const {
if (IsMultiTarget()) {
return this->p_mt_tree_->Size();
}
return this->nodes_.size();
}
private: private:
template <bool typed> template <bool typed>
void LoadCategoricalSplit(Json const& in); void LoadCategoricalSplit(Json const& in);
@ -622,8 +746,9 @@ class RegTree : public Model {
// Categories for each internal node. // Categories for each internal node.
std::vector<uint32_t> split_categories_; std::vector<uint32_t> split_categories_;
// Ptr to split categories of each node. // Ptr to split categories of each node.
std::vector<Segment> split_categories_segments_; std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
// ptr to multi-target tree with vector leaf.
CopyUniquePtr<MultiTargetTree> p_mt_tree_;
// allocate a new node, // allocate a new node,
// !!!!!! NOTE: may cause BUG here, nodes.resize // !!!!!! NOTE: may cause BUG here, nodes.resize
bst_node_t AllocNode() { bst_node_t AllocNode() {
@ -703,5 +828,10 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
inline bool RegTree::FVec::HasMissing() const { inline bool RegTree::FVec::HasMissing() const {
return has_missing_; return has_missing_;
} }
// Multi-target tree not yet implemented error
inline StringView MTNotImplemented() {
return " support for multi-target tree is not yet implemented.";
}
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_TREE_MODEL_H_ #endif // XGBOOST_TREE_MODEL_H_

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2017 XGBoost contributors * Copyright 2017-2023 by XGBoost contributors
*/ */
#ifndef XGBOOST_USE_CUDA #ifndef XGBOOST_USE_CUDA
@ -179,7 +179,6 @@ template class HostDeviceVector<FeatureType>;
template class HostDeviceVector<Entry>; template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Segment>;
#if defined(__APPLE__) || defined(__EMSCRIPTEN__) #if defined(__APPLE__) || defined(__EMSCRIPTEN__)
/* /*

View File

@ -1,7 +1,6 @@
/*! /**
* Copyright 2017 XGBoost contributors * Copyright 2017-2023 by XGBoost contributors
*/ */
#include <thrust/fill.h> #include <thrust/fill.h>
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
@ -412,7 +411,7 @@ template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Node>; template class HostDeviceVector<RegTree::Node>;
template class HostDeviceVector<RegTree::Segment>; template class HostDeviceVector<RegTree::CategoricalSplitMatrix::Segment>;
template class HostDeviceVector<RTreeNodeStat>; template class HostDeviceVector<RTreeNodeStat>;
#if defined(__APPLE__) #if defined(__APPLE__)

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2017-2021 by Contributors * Copyright 2017-2023 by XGBoost Contributors
*/ */
#include <GPUTreeShap/gpu_treeshap.h> #include <GPUTreeShap/gpu_treeshap.h>
#include <thrust/copy.h> #include <thrust/copy.h>
@ -25,9 +25,7 @@
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h" #include "xgboost/tree_updater.h"
namespace xgboost { namespace xgboost::predictor {
namespace predictor {
DMLC_REGISTRY_FILE_TAG(gpu_predictor); DMLC_REGISTRY_FILE_TAG(gpu_predictor);
struct TreeView { struct TreeView {
@ -35,12 +33,11 @@ struct TreeView {
common::Span<RegTree::Node const> d_tree; common::Span<RegTree::Node const> d_tree;
XGBOOST_DEVICE XGBOOST_DEVICE
TreeView(size_t tree_begin, size_t tree_idx, TreeView(size_t tree_begin, size_t tree_idx, common::Span<const RegTree::Node> d_nodes,
common::Span<const RegTree::Node> d_nodes,
common::Span<size_t const> d_tree_segments, common::Span<size_t const> d_tree_segments,
common::Span<FeatureType const> d_tree_split_types, common::Span<FeatureType const> d_tree_split_types,
common::Span<uint32_t const> d_cat_tree_segments, common::Span<uint32_t const> d_cat_tree_segments,
common::Span<RegTree::Segment const> d_cat_node_segments, common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
common::Span<uint32_t const> d_categories) { common::Span<uint32_t const> d_categories) {
auto begin = d_tree_segments[tree_idx - tree_begin]; auto begin = d_tree_segments[tree_idx - tree_begin];
auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] - auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] -
@ -255,7 +252,7 @@ PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
common::Span<FeatureType const> d_tree_split_types, common::Span<FeatureType const> d_tree_split_types,
common::Span<uint32_t const> d_cat_tree_segments, common::Span<uint32_t const> d_cat_tree_segments,
common::Span<RegTree::Segment const> d_cat_node_segments, common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
common::Span<uint32_t const> d_categories, common::Span<uint32_t const> d_categories,
size_t tree_begin, size_t tree_end, size_t num_features, size_t tree_begin, size_t tree_end, size_t num_features,
@ -290,7 +287,7 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
common::Span<int const> d_tree_group, common::Span<int const> d_tree_group,
common::Span<FeatureType const> d_tree_split_types, common::Span<FeatureType const> d_tree_split_types,
common::Span<uint32_t const> d_cat_tree_segments, common::Span<uint32_t const> d_cat_tree_segments,
common::Span<RegTree::Segment const> d_cat_node_segments, common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
common::Span<uint32_t const> d_categories, size_t tree_begin, common::Span<uint32_t const> d_categories, size_t tree_begin,
size_t tree_end, size_t num_features, size_t num_rows, size_t tree_end, size_t num_features, size_t num_rows,
size_t entry_start, bool use_shared, int num_group, float missing) { size_t entry_start, bool use_shared, int num_group, float missing) {
@ -334,7 +331,7 @@ class DeviceModel {
// Pointer to each tree, segmenting the node array. // Pointer to each tree, segmenting the node array.
HostDeviceVector<uint32_t> categories_tree_segments; HostDeviceVector<uint32_t> categories_tree_segments;
// Pointer to each node, segmenting categories array. // Pointer to each node, segmenting categories array.
HostDeviceVector<RegTree::Segment> categories_node_segments; HostDeviceVector<RegTree::CategoricalSplitMatrix::Segment> categories_node_segments;
HostDeviceVector<uint32_t> categories; HostDeviceVector<uint32_t> categories;
size_t tree_beg_; // NOLINT size_t tree_beg_; // NOLINT
@ -400,9 +397,9 @@ class DeviceModel {
h_split_cat_segments.push_back(h_categories.size()); h_split_cat_segments.push_back(h_categories.size());
} }
categories_node_segments = categories_node_segments = HostDeviceVector<RegTree::CategoricalSplitMatrix::Segment>(
HostDeviceVector<RegTree::Segment>(h_tree_segments.back(), {}, gpu_id); h_tree_segments.back(), {}, gpu_id);
std::vector<RegTree::Segment> &h_categories_node_segments = std::vector<RegTree::CategoricalSplitMatrix::Segment>& h_categories_node_segments =
categories_node_segments.HostVector(); categories_node_segments.HostVector();
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr(); auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr();
@ -542,10 +539,10 @@ void ExtractPaths(
if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types), if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types),
common::IsCatOp{})) { common::IsCatOp{})) {
dh::PinnedMemory pinned; dh::PinnedMemory pinned;
auto h_max_cat = pinned.GetSpan<RegTree::Segment>(1); auto h_max_cat = pinned.GetSpan<RegTree::CategoricalSplitMatrix::Segment>(1);
auto max_elem_it = dh::MakeTransformIterator<size_t>( auto max_elem_it = dh::MakeTransformIterator<size_t>(
dh::tbegin(d_cat_node_segments), dh::tbegin(d_cat_node_segments),
[] __device__(RegTree::Segment seg) { return seg.size; }); [] __device__(RegTree::CategoricalSplitMatrix::Segment seg) { return seg.size; });
size_t max_cat_it = size_t max_cat_it =
thrust::max_element(thrust::device, max_elem_it, thrust::max_element(thrust::device, max_elem_it,
max_elem_it + d_cat_node_segments.size()) - max_elem_it + d_cat_node_segments.size()) -
@ -1028,5 +1025,4 @@ XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
.describe("Make predictions using GPU.") .describe("Make predictions using GPU.")
.set_body([](Context const* ctx) { return new GPUPredictor(ctx); }); .set_body([](Context const* ctx) { return new GPUPredictor(ctx); });
} // namespace predictor } // namespace xgboost::predictor
} // namespace xgboost

View File

@ -71,10 +71,7 @@ void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
auto n_samples = gpair.Size() / n_targets; auto n_samples = gpair.Size() / n_targets;
gpair.SetDevice(ctx->gpu_id); gpair.SetDevice(ctx->gpu_id);
linalg::TensorView<GradientPair const, 2> gpair_t{ auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
ctx->IsCPU() ? gpair.ConstHostSpan() : gpair.ConstDeviceSpan(),
{n_samples, n_targets},
ctx->gpu_id};
ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView()) ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id)); : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
} }

View File

@ -12,7 +12,7 @@
#include "../../common/hist_util.h" #include "../../common/hist_util.h"
#include "../../data/gradient_index.h" #include "../../data/gradient_index.h"
#include "expand_entry.h" #include "expand_entry.h"
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h" // for RegTree
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -175,8 +175,8 @@ class HistogramBuilder {
auto this_local = hist_local_worker_[entry.nid]; auto this_local = hist_local_worker_[entry.nid];
common::CopyHist(this_local, this_hist, r.begin(), r.end()); common::CopyHist(this_local, this_hist, r.begin(), r.end());
if (!(*p_tree)[entry.nid].IsRoot()) { if (!p_tree->IsRoot(entry.nid)) {
const size_t parent_id = (*p_tree)[entry.nid].Parent(); const size_t parent_id = p_tree->Parent(entry.nid);
const int subtraction_node_id = nodes_for_subtraction_trick[node].nid; const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_local_worker_[parent_id]; auto parent_hist = this->hist_local_worker_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id]; auto sibling_hist = this->hist_[subtraction_node_id];
@ -213,8 +213,8 @@ class HistogramBuilder {
// Merging histograms from each thread into once // Merging histograms from each thread into once
this->buffer_.ReduceHist(node, r.begin(), r.end()); this->buffer_.ReduceHist(node, r.begin(), r.end());
if (!(*p_tree)[entry.nid].IsRoot()) { if (!p_tree->IsRoot(entry.nid)) {
auto const parent_id = (*p_tree)[entry.nid].Parent(); auto const parent_id = p_tree->Parent(entry.nid);
auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid; auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_[parent_id]; auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id]; auto sibling_hist = this->hist_[subtraction_node_id];
@ -237,10 +237,10 @@ class HistogramBuilder {
common::ParallelFor2d( common::ParallelFor2d(
space, this->n_threads_, [&](size_t node, common::Range1d r) { space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes[node]; const auto &entry = nodes[node];
if (!((*p_tree)[entry.nid].IsLeftChild())) { if (!(p_tree->IsLeftChild(entry.nid))) {
auto this_hist = this->hist_[entry.nid]; auto this_hist = this->hist_[entry.nid];
if (!(*p_tree)[entry.nid].IsRoot()) { if (!p_tree->IsRoot(entry.nid)) {
const int subtraction_node_id = subtraction_nodes[node].nid; const int subtraction_node_id = subtraction_nodes[node].nid;
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()]; auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
auto sibling_hist = hist_[subtraction_node_id]; auto sibling_hist = hist_[subtraction_node_id];
@ -285,7 +285,7 @@ class HistogramBuilder {
std::sort(merged_node_ids.begin(), merged_node_ids.end()); std::sort(merged_node_ids.begin(), merged_node_ids.end());
int n_left = 0; int n_left = 0;
for (auto const &nid : merged_node_ids) { for (auto const &nid : merged_node_ids) {
if ((*p_tree)[nid].IsLeftChild()) { if (p_tree->IsLeftChild(nid)) {
this->hist_.AddHistRow(nid); this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index)); (*starting_index) = std::min(nid, (*starting_index));
n_left++; n_left++;
@ -293,7 +293,7 @@ class HistogramBuilder {
} }
} }
for (auto const &nid : merged_node_ids) { for (auto const &nid : merged_node_ids) {
if (!((*p_tree)[nid].IsLeftChild())) { if (!(p_tree->IsLeftChild(nid))) {
this->hist_.AddHistRow(nid); this->hist_.AddHistRow(nid);
this->hist_local_worker_.AddHistRow(nid); this->hist_local_worker_.AddHistRow(nid);
} }

65
src/tree/io_utils.h Normal file
View File

@ -0,0 +1,65 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#ifndef XGBOOST_TREE_IO_UTILS_H_
#define XGBOOST_TREE_IO_UTILS_H_
#include <string> // for string
#include <type_traits> // for enable_if_t, is_same, conditional_t
#include <vector> // for vector
#include "xgboost/json.h" // for Json
namespace xgboost {
template <bool typed>
using FloatArrayT = std::conditional_t<typed, F32Array const, Array const>;
template <bool typed>
using U8ArrayT = std::conditional_t<typed, U8Array const, Array const>;
template <bool typed>
using I32ArrayT = std::conditional_t<typed, I32Array const, Array const>;
template <bool typed>
using I64ArrayT = std::conditional_t<typed, I64Array const, Array const>;
template <bool typed, bool feature_is_64>
using IndexArrayT = std::conditional_t<feature_is_64, I64ArrayT<typed>, I32ArrayT<typed>>;
// typed array, not boolean
template <typename JT, typename T>
std::enable_if_t<!std::is_same<T, Json>::value && !std::is_same<JT, Boolean>::value, T> GetElem(
std::vector<T> const& arr, size_t i) {
return arr[i];
}
// typed array boolean
template <typename JT, typename T>
std::enable_if_t<!std::is_same<T, Json>::value && std::is_same<T, uint8_t>::value &&
std::is_same<JT, Boolean>::value,
bool>
GetElem(std::vector<T> const& arr, size_t i) {
return arr[i] == 1;
}
// json array
template <typename JT, typename T>
std::enable_if_t<
std::is_same<T, Json>::value,
std::conditional_t<std::is_same<JT, Integer>::value, int64_t,
std::conditional_t<std::is_same<Boolean, JT>::value, bool, float>>>
GetElem(std::vector<T> const& arr, size_t i) {
if (std::is_same<JT, Boolean>::value && !IsA<Boolean>(arr[i])) {
return get<Integer const>(arr[i]) == 1;
}
return get<JT const>(arr[i]);
}
namespace tree_field {
inline std::string const kLossChg{"loss_changes"};
inline std::string const kSumHess{"sum_hessian"};
inline std::string const kBaseWeight{"base_weights"};
inline std::string const kSplitIdx{"split_indices"};
inline std::string const kSplitCond{"split_conditions"};
inline std::string const kDftLeft{"default_left"};
inline std::string const kParent{"parents"};
inline std::string const kLeft{"left_children"};
inline std::string const kRight{"right_children"};
} // namespace tree_field
} // namespace xgboost
#endif // XGBOOST_TREE_IO_UTILS_H_

View File

@ -0,0 +1,220 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include "xgboost/multi_target_tree_model.h"
#include <algorithm> // for copy_n
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, uint8_t
#include <limits> // for numeric_limits
#include <string_view> // for string_view
#include <utility> // for move
#include <vector> // for vector
#include "io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
#include "xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
#include "xgboost/json.h" // for Json, get, Object, Number, Integer, ...
#include "xgboost/logging.h"
#include "xgboost/tree_model.h" // for TreeParam
namespace xgboost {
MultiTargetTree::MultiTargetTree(TreeParam const* param)
: param_{param},
left_(1ul, InvalidNodeId()),
right_(1ul, InvalidNodeId()),
parent_(1ul, InvalidNodeId()),
split_index_(1ul, 0),
default_left_(1ul, 0),
split_conds_(1ul, std::numeric_limits<float>::quiet_NaN()),
weights_(param->size_leaf_vector, std::numeric_limits<float>::quiet_NaN()) {
CHECK_GT(param_->size_leaf_vector, 1);
}
template <bool typed, bool feature_is_64>
void LoadModelImpl(Json const& in, std::vector<float>* p_weights, std::vector<bst_node_t>* p_lefts,
std::vector<bst_node_t>* p_rights, std::vector<bst_node_t>* p_parents,
std::vector<float>* p_conds, std::vector<bst_feature_t>* p_fidx,
std::vector<std::uint8_t>* p_dft_left) {
namespace tf = tree_field;
auto get_float = [&](std::string_view name, std::vector<float>* p_out) {
auto& values = get<FloatArrayT<typed>>(get<Object const>(in).find(name)->second);
auto& out = *p_out;
out.resize(values.size());
for (std::size_t i = 0; i < values.size(); ++i) {
out[i] = GetElem<Number>(values, i);
}
};
get_float(tf::kBaseWeight, p_weights);
get_float(tf::kSplitCond, p_conds);
auto get_nidx = [&](std::string_view name, std::vector<bst_node_t>* p_nidx) {
auto& nidx = get<I32ArrayT<typed>>(get<Object const>(in).find(name)->second);
auto& out_nidx = *p_nidx;
out_nidx.resize(nidx.size());
for (std::size_t i = 0; i < nidx.size(); ++i) {
out_nidx[i] = GetElem<Integer>(nidx, i);
}
};
get_nidx(tf::kLeft, p_lefts);
get_nidx(tf::kRight, p_rights);
get_nidx(tf::kParent, p_parents);
auto const& splits = get<IndexArrayT<typed, feature_is_64> const>(in[tf::kSplitIdx]);
p_fidx->resize(splits.size());
auto& out_fidx = *p_fidx;
for (std::size_t i = 0; i < splits.size(); ++i) {
out_fidx[i] = GetElem<Integer>(splits, i);
}
auto const& dft_left = get<U8ArrayT<typed> const>(in[tf::kDftLeft]);
auto& out_dft_l = *p_dft_left;
out_dft_l.resize(dft_left.size());
for (std::size_t i = 0; i < dft_left.size(); ++i) {
out_dft_l[i] = GetElem<Boolean>(dft_left, i);
}
}
void MultiTargetTree::LoadModel(Json const& in) {
namespace tf = tree_field;
bool typed = IsA<F32Array>(in[tf::kBaseWeight]);
bool feature_is_64 = IsA<I64Array>(in[tf::kSplitIdx]);
if (typed && feature_is_64) {
LoadModelImpl<true, true>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
&split_index_, &default_left_);
} else if (typed && !feature_is_64) {
LoadModelImpl<true, false>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
&split_index_, &default_left_);
} else if (!typed && feature_is_64) {
LoadModelImpl<false, true>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
&split_index_, &default_left_);
} else {
LoadModelImpl<false, false>(in, &weights_, &left_, &right_, &parent_, &split_conds_,
&split_index_, &default_left_);
}
}
void MultiTargetTree::SaveModel(Json* p_out) const {
CHECK(p_out);
auto& out = *p_out;
auto n_nodes = param_->num_nodes;
// nodes
I32Array lefts(n_nodes);
I32Array rights(n_nodes);
I32Array parents(n_nodes);
F32Array conds(n_nodes);
U8Array default_left(n_nodes);
F32Array weights(n_nodes * this->NumTarget());
auto save_tree = [&](auto* p_indices_array) {
auto& indices_array = *p_indices_array;
for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) {
CHECK_LT(nidx, left_.size());
lefts.Set(nidx, left_[nidx]);
CHECK_LT(nidx, right_.size());
rights.Set(nidx, right_[nidx]);
CHECK_LT(nidx, parent_.size());
parents.Set(nidx, parent_[nidx]);
CHECK_LT(nidx, split_index_.size());
indices_array.Set(nidx, split_index_[nidx]);
conds.Set(nidx, split_conds_[nidx]);
default_left.Set(nidx, default_left_[nidx]);
auto in_weight = this->NodeWeight(nidx);
auto weight_out = common::Span<float>(weights.GetArray())
.subspan(nidx * this->NumTarget(), this->NumTarget());
CHECK_EQ(in_weight.Size(), weight_out.size());
std::copy_n(in_weight.Values().data(), in_weight.Size(), weight_out.data());
}
};
namespace tf = tree_field;
if (this->param_->num_feature >
static_cast<bst_feature_t>(std::numeric_limits<std::int32_t>::max())) {
I64Array indices_64(n_nodes);
save_tree(&indices_64);
out[tf::kSplitIdx] = std::move(indices_64);
} else {
I32Array indices_32(n_nodes);
save_tree(&indices_32);
out[tf::kSplitIdx] = std::move(indices_32);
}
out[tf::kBaseWeight] = std::move(weights);
out[tf::kLeft] = std::move(lefts);
out[tf::kRight] = std::move(rights);
out[tf::kParent] = std::move(parents);
out[tf::kSplitCond] = std::move(conds);
out[tf::kDftLeft] = std::move(default_left);
}
void MultiTargetTree::SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
CHECK(this->IsLeaf(nidx)) << "Collapsing a split node to leaf " << MTNotImplemented();
auto const next_nidx = nidx + 1;
CHECK_EQ(weight.Size(), this->NumTarget());
CHECK_GE(weights_.size(), next_nidx * weight.Size());
auto out_weight = common::Span<float>(weights_).subspan(nidx * weight.Size(), weight.Size());
for (std::size_t i = 0; i < weight.Size(); ++i) {
out_weight[i] = weight(i);
}
}
void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond,
bool default_left, linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight) {
CHECK(this->IsLeaf(nidx));
CHECK_GE(parent_.size(), 1);
CHECK_EQ(parent_.size(), left_.size());
CHECK_EQ(left_.size(), right_.size());
std::size_t n = param_->num_nodes + 2;
CHECK_LT(split_idx, this->param_->num_feature);
left_.resize(n, InvalidNodeId());
right_.resize(n, InvalidNodeId());
parent_.resize(n, InvalidNodeId());
auto left_child = parent_.size() - 2;
auto right_child = parent_.size() - 1;
left_[nidx] = left_child;
right_[nidx] = right_child;
if (nidx != 0) {
CHECK_NE(parent_[nidx], InvalidNodeId());
}
parent_[left_child] = nidx;
parent_[right_child] = nidx;
split_index_.resize(n);
split_index_[nidx] = split_idx;
split_conds_.resize(n);
split_conds_[nidx] = split_cond;
default_left_.resize(n);
default_left_[nidx] = static_cast<std::uint8_t>(default_left);
weights_.resize(n * this->NumTarget());
auto p_weight = this->NodeWeight(nidx);
CHECK_EQ(p_weight.Size(), base_weight.Size());
auto l_weight = this->NodeWeight(left_child);
CHECK_EQ(l_weight.Size(), left_weight.Size());
auto r_weight = this->NodeWeight(right_child);
CHECK_EQ(r_weight.Size(), right_weight.Size());
for (std::size_t i = 0; i < base_weight.Size(); ++i) {
p_weight(i) = base_weight(i);
l_weight(i) = left_weight(i);
r_weight(i) = right_weight(i);
}
}
bst_target_t MultiTargetTree::NumTarget() const { return param_->size_leaf_vector; }
std::size_t MultiTargetTree::Size() const { return parent_.size(); }
} // namespace xgboost

View File

@ -1,25 +1,27 @@
/*! /**
* Copyright 2015-2022 by Contributors * Copyright 2015-2023 by Contributors
* \file tree_model.cc * \file tree_model.cc
* \brief model structure for tree * \brief model structure for tree
*/ */
#include <dmlc/registry.h>
#include <dmlc/json.h> #include <dmlc/json.h>
#include <dmlc/registry.h>
#include <xgboost/tree_model.h>
#include <xgboost/logging.h>
#include <xgboost/json.h> #include <xgboost/json.h>
#include <xgboost/tree_model.h>
#include <sstream>
#include <limits>
#include <cmath> #include <cmath>
#include <iomanip> #include <iomanip>
#include <stack> #include <limits>
#include <sstream>
#include <type_traits>
#include "param.h"
#include "../common/common.h"
#include "../common/categorical.h" #include "../common/categorical.h"
#include "../common/common.h"
#include "../predictor/predict_fn.h" #include "../predictor/predict_fn.h"
#include "io_utils.h" // GetElem
#include "param.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/logging.h"
namespace xgboost { namespace xgboost {
// register tree parameter // register tree parameter
@ -729,12 +731,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
constexpr bst_node_t RegTree::kRoot; constexpr bst_node_t RegTree::kRoot;
std::string RegTree::DumpModel(const FeatureMap& fmap, std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const {
bool with_stats, CHECK(!IsMultiTarget());
std::string format) const { std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)};
std::unique_ptr<TreeGenerator> builder {
TreeGenerator::Create(format, fmap, with_stats)
};
builder->BuildTree(*this); builder->BuildTree(*this);
std::string result = builder->Str(); std::string result = builder->Str();
@ -742,6 +741,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap,
} }
bool RegTree::Equal(const RegTree& b) const { bool RegTree::Equal(const RegTree& b) const {
CHECK(!IsMultiTarget());
if (NumExtraNodes() != b.NumExtraNodes()) { if (NumExtraNodes() != b.NumExtraNodes()) {
return false; return false;
} }
@ -758,6 +758,7 @@ bool RegTree::Equal(const RegTree& b) const {
} }
bst_node_t RegTree::GetNumLeaves() const { bst_node_t RegTree::GetNumLeaves() const {
CHECK(!IsMultiTarget());
bst_node_t leaves { 0 }; bst_node_t leaves { 0 };
auto const& self = *this; auto const& self = *this;
this->WalkTree([&leaves, &self](bst_node_t nidx) { this->WalkTree([&leaves, &self](bst_node_t nidx) {
@ -770,6 +771,7 @@ bst_node_t RegTree::GetNumLeaves() const {
} }
bst_node_t RegTree::GetNumSplitNodes() const { bst_node_t RegTree::GetNumSplitNodes() const {
CHECK(!IsMultiTarget());
bst_node_t splits { 0 }; bst_node_t splits { 0 };
auto const& self = *this; auto const& self = *this;
this->WalkTree([&splits, &self](bst_node_t nidx) { this->WalkTree([&splits, &self](bst_node_t nidx) {
@ -787,6 +789,7 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
bst_float right_leaf_weight, bst_float loss_change, bst_float right_leaf_weight, bst_float loss_change,
float sum_hess, float left_sum, float right_sum, float sum_hess, float left_sum, float right_sum,
bst_node_t leaf_right_child) { bst_node_t leaf_right_child) {
CHECK(!IsMultiTarget());
int pleft = this->AllocNode(); int pleft = this->AllocNode();
int pright = this->AllocNode(); int pright = this->AllocNode();
auto &node = nodes_[nid]; auto &node = nodes_[nid];
@ -807,11 +810,31 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
this->split_types_.at(nid) = FeatureType::kNumerical; this->split_types_.at(nid) = FeatureType::kNumerical;
} }
void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond,
bool default_left, linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight) {
CHECK(IsMultiTarget());
CHECK_LT(split_index, this->param.num_feature);
CHECK(this->p_mt_tree_);
CHECK_GT(param.size_leaf_vector, 1);
this->p_mt_tree_->Expand(nidx, split_index, split_cond, default_left, base_weight, left_weight,
right_weight);
split_types_.resize(this->Size(), FeatureType::kNumerical);
split_categories_segments_.resize(this->Size());
this->split_types_.at(nidx) = FeatureType::kNumerical;
this->param.num_nodes = this->p_mt_tree_->Size();
}
void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index, void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
common::Span<const uint32_t> split_cat, bool default_left, common::Span<const uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight, bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change, float sum_hess, bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
float left_sum, float right_sum) { float left_sum, float right_sum) {
CHECK(!IsMultiTarget());
this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(), this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
default_left, base_weight, default_left, base_weight,
left_leaf_weight, right_leaf_weight, loss_change, sum_hess, left_leaf_weight, right_leaf_weight, loss_change, sum_hess,
@ -893,44 +916,17 @@ void RegTree::Save(dmlc::Stream* fo) const {
} }
} }
} }
// typed array, not boolean
template <typename JT, typename T>
std::enable_if_t<!std::is_same<T, Json>::value && !std::is_same<JT, Boolean>::value, T> GetElem(
std::vector<T> const& arr, size_t i) {
return arr[i];
}
// typed array boolean
template <typename JT, typename T>
std::enable_if_t<!std::is_same<T, Json>::value && std::is_same<T, uint8_t>::value &&
std::is_same<JT, Boolean>::value,
bool>
GetElem(std::vector<T> const& arr, size_t i) {
return arr[i] == 1;
}
// json array
template <typename JT, typename T>
std::enable_if_t<
std::is_same<T, Json>::value,
std::conditional_t<std::is_same<JT, Integer>::value, int64_t,
std::conditional_t<std::is_same<Boolean, JT>::value, bool, float>>>
GetElem(std::vector<T> const& arr, size_t i) {
if (std::is_same<JT, Boolean>::value && !IsA<Boolean>(arr[i])) {
return get<Integer const>(arr[i]) == 1;
}
return get<JT const>(arr[i]);
}
template <bool typed> template <bool typed>
void RegTree::LoadCategoricalSplit(Json const& in) { void RegTree::LoadCategoricalSplit(Json const& in) {
using I64ArrayT = std::conditional_t<typed, I64Array const, Array const>; auto const& categories_segments = get<I64ArrayT<typed>>(in["categories_segments"]);
using I32ArrayT = std::conditional_t<typed, I32Array const, Array const>; auto const& categories_sizes = get<I64ArrayT<typed>>(in["categories_sizes"]);
auto const& categories_nodes = get<I32ArrayT<typed>>(in["categories_nodes"]);
auto const& categories = get<I32ArrayT<typed>>(in["categories"]);
auto const& categories_segments = get<I64ArrayT>(in["categories_segments"]); auto split_type = get<U8ArrayT<typed>>(in["split_type"]);
auto const& categories_sizes = get<I64ArrayT>(in["categories_sizes"]); bst_node_t n_nodes = split_type.size();
auto const& categories_nodes = get<I32ArrayT>(in["categories_nodes"]); std::size_t cnt = 0;
auto const& categories = get<I32ArrayT>(in["categories"]);
size_t cnt = 0;
bst_node_t last_cat_node = -1; bst_node_t last_cat_node = -1;
if (!categories_nodes.empty()) { if (!categories_nodes.empty()) {
last_cat_node = GetElem<Integer>(categories_nodes, cnt); last_cat_node = GetElem<Integer>(categories_nodes, cnt);
@ -938,7 +934,10 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
// `categories_segments' is only available for categorical nodes to prevent overhead for // `categories_segments' is only available for categorical nodes to prevent overhead for
// numerical node. As a result, we need to track the categorical nodes we have processed // numerical node. As a result, we need to track the categorical nodes we have processed
// so far. // so far.
for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) { split_types_.resize(n_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(n_nodes);
for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) {
split_types_[nidx] = static_cast<FeatureType>(GetElem<Integer>(split_type, nidx));
if (nidx == last_cat_node) { if (nidx == last_cat_node) {
auto j_begin = GetElem<Integer>(categories_segments, cnt); auto j_begin = GetElem<Integer>(categories_segments, cnt);
auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin; auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin;
@ -985,15 +984,17 @@ template void RegTree::LoadCategoricalSplit<false>(Json const& in);
void RegTree::SaveCategoricalSplit(Json* p_out) const { void RegTree::SaveCategoricalSplit(Json* p_out) const {
auto& out = *p_out; auto& out = *p_out;
CHECK_EQ(this->split_types_.size(), param.num_nodes); CHECK_EQ(this->split_types_.size(), this->Size());
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes); CHECK_EQ(this->GetSplitCategoriesPtr().size(), this->Size());
I64Array categories_segments; I64Array categories_segments;
I64Array categories_sizes; I64Array categories_sizes;
I32Array categories; // bst_cat_t = int32_t I32Array categories; // bst_cat_t = int32_t
I32Array categories_nodes; // bst_note_t = int32_t I32Array categories_nodes; // bst_note_t = int32_t
U8Array split_type(split_types_.size());
for (size_t i = 0; i < nodes_.size(); ++i) { for (size_t i = 0; i < nodes_.size(); ++i) {
split_type.Set(i, static_cast<std::underlying_type_t<FeatureType>>(this->NodeSplitType(i)));
if (this->split_types_[i] == FeatureType::kCategorical) { if (this->split_types_[i] == FeatureType::kCategorical) {
categories_nodes.GetArray().emplace_back(i); categories_nodes.GetArray().emplace_back(i);
auto begin = categories.Size(); auto begin = categories.Size();
@ -1012,66 +1013,49 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
} }
} }
out["split_type"] = std::move(split_type);
out["categories_segments"] = std::move(categories_segments); out["categories_segments"] = std::move(categories_segments);
out["categories_sizes"] = std::move(categories_sizes); out["categories_sizes"] = std::move(categories_sizes);
out["categories_nodes"] = std::move(categories_nodes); out["categories_nodes"] = std::move(categories_nodes);
out["categories"] = std::move(categories); out["categories"] = std::move(categories);
} }
template <bool typed, bool feature_is_64, template <bool typed, bool feature_is_64>
typename FloatArrayT = std::conditional_t<typed, F32Array const, Array const>, void LoadModelImpl(Json const& in, TreeParam const& param, std::vector<RTreeNodeStat>* p_stats,
typename U8ArrayT = std::conditional_t<typed, U8Array const, Array const>, std::vector<RegTree::Node>* p_nodes) {
typename I32ArrayT = std::conditional_t<typed, I32Array const, Array const>, namespace tf = tree_field;
typename I64ArrayT = std::conditional_t<typed, I64Array const, Array const>,
typename IndexArrayT = std::conditional_t<feature_is_64, I64ArrayT, I32ArrayT>>
bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>* p_stats,
std::vector<FeatureType>* p_split_types, std::vector<RegTree::Node>* p_nodes,
std::vector<RegTree::Segment>* p_split_categories_segments) {
auto& stats = *p_stats; auto& stats = *p_stats;
auto& split_types = *p_split_types;
auto& nodes = *p_nodes; auto& nodes = *p_nodes;
auto& split_categories_segments = *p_split_categories_segments;
FromJson(in["tree_param"], param); auto n_nodes = param.num_nodes;
auto n_nodes = param->num_nodes;
CHECK_NE(n_nodes, 0); CHECK_NE(n_nodes, 0);
// stats // stats
auto const& loss_changes = get<FloatArrayT>(in["loss_changes"]); auto const& loss_changes = get<FloatArrayT<typed>>(in[tf::kLossChg]);
CHECK_EQ(loss_changes.size(), n_nodes); CHECK_EQ(loss_changes.size(), n_nodes);
auto const& sum_hessian = get<FloatArrayT>(in["sum_hessian"]); auto const& sum_hessian = get<FloatArrayT<typed>>(in[tf::kSumHess]);
CHECK_EQ(sum_hessian.size(), n_nodes); CHECK_EQ(sum_hessian.size(), n_nodes);
auto const& base_weights = get<FloatArrayT>(in["base_weights"]); auto const& base_weights = get<FloatArrayT<typed>>(in[tf::kBaseWeight]);
CHECK_EQ(base_weights.size(), n_nodes); CHECK_EQ(base_weights.size(), n_nodes);
// nodes // nodes
auto const& lefts = get<I32ArrayT>(in["left_children"]); auto const& lefts = get<I32ArrayT<typed>>(in[tf::kLeft]);
CHECK_EQ(lefts.size(), n_nodes); CHECK_EQ(lefts.size(), n_nodes);
auto const& rights = get<I32ArrayT>(in["right_children"]); auto const& rights = get<I32ArrayT<typed>>(in[tf::kRight]);
CHECK_EQ(rights.size(), n_nodes); CHECK_EQ(rights.size(), n_nodes);
auto const& parents = get<I32ArrayT>(in["parents"]); auto const& parents = get<I32ArrayT<typed>>(in[tf::kParent]);
CHECK_EQ(parents.size(), n_nodes); CHECK_EQ(parents.size(), n_nodes);
auto const& indices = get<IndexArrayT>(in["split_indices"]); auto const& indices = get<IndexArrayT<typed, feature_is_64>>(in[tf::kSplitIdx]);
CHECK_EQ(indices.size(), n_nodes); CHECK_EQ(indices.size(), n_nodes);
auto const& conds = get<FloatArrayT>(in["split_conditions"]); auto const& conds = get<FloatArrayT<typed>>(in[tf::kSplitCond]);
CHECK_EQ(conds.size(), n_nodes); CHECK_EQ(conds.size(), n_nodes);
auto const& default_left = get<U8ArrayT>(in["default_left"]); auto const& default_left = get<U8ArrayT<typed>>(in[tf::kDftLeft]);
CHECK_EQ(default_left.size(), n_nodes); CHECK_EQ(default_left.size(), n_nodes);
bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
std::remove_const_t<std::remove_reference_t<decltype(get<U8ArrayT const>(in["split_type"]))>>
split_type;
if (has_cat) {
split_type = get<U8ArrayT const>(in["split_type"]);
}
// Initialization // Initialization
stats = std::remove_reference_t<decltype(stats)>(n_nodes); stats = std::remove_reference_t<decltype(stats)>(n_nodes);
nodes = std::remove_reference_t<decltype(nodes)>(n_nodes); nodes = std::remove_reference_t<decltype(nodes)>(n_nodes);
split_types = std::remove_reference_t<decltype(split_types)>(n_nodes);
split_categories_segments = std::remove_reference_t<decltype(split_categories_segments)>(n_nodes);
static_assert(std::is_integral<decltype(GetElem<Integer>(lefts, 0))>::value); static_assert(std::is_integral<decltype(GetElem<Integer>(lefts, 0))>::value);
static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value); static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value);
CHECK_EQ(n_nodes, split_categories_segments.size());
// Set node // Set node
for (int32_t i = 0; i < n_nodes; ++i) { for (int32_t i = 0; i < n_nodes; ++i) {
@ -1088,41 +1072,46 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
float cond{GetElem<Number>(conds, i)}; float cond{GetElem<Number>(conds, i)};
bool dft_left{GetElem<Boolean>(default_left, i)}; bool dft_left{GetElem<Boolean>(default_left, i)};
n = RegTree::Node{left, right, parent, ind, cond, dft_left}; n = RegTree::Node{left, right, parent, ind, cond, dft_left};
if (has_cat) {
split_types[i] = static_cast<FeatureType>(GetElem<Integer>(split_type, i));
} }
} }
return has_cat;
}
void RegTree::LoadModel(Json const& in) { void RegTree::LoadModel(Json const& in) {
bool has_cat{false}; namespace tf = tree_field;
bool typed = IsA<F32Array>(in["loss_changes"]);
bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
if (typed && feature_is_64) {
has_cat = LoadModelImpl<true, true>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else if (typed && !feature_is_64) {
has_cat = LoadModelImpl<true, false>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else if (!typed && feature_is_64) {
has_cat = LoadModelImpl<false, true>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else {
has_cat = LoadModelImpl<false, false>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
}
bool typed = IsA<I32Array>(in[tf::kParent]);
auto const& in_obj = get<Object const>(in);
// basic properties
FromJson(in["tree_param"], &param);
// categorical splits
bool has_cat = in_obj.find("split_type") != in_obj.cend();
if (has_cat) { if (has_cat) {
if (typed) { if (typed) {
this->LoadCategoricalSplit<true>(in); this->LoadCategoricalSplit<true>(in);
} else { } else {
this->LoadCategoricalSplit<false>(in); this->LoadCategoricalSplit<false>(in);
} }
}
// multi-target
if (param.size_leaf_vector > 1) {
this->p_mt_tree_.reset(new MultiTargetTree{&param});
this->GetMultiTargetTree()->LoadModel(in);
return;
}
bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
if (typed && feature_is_64) {
LoadModelImpl<true, true>(in, param, &stats_, &nodes_);
} else if (typed && !feature_is_64) {
LoadModelImpl<true, false>(in, param, &stats_, &nodes_);
} else if (!typed && feature_is_64) {
LoadModelImpl<false, true>(in, param, &stats_, &nodes_);
} else { } else {
LoadModelImpl<false, false>(in, param, &stats_, &nodes_);
}
if (!has_cat) {
this->split_categories_segments_.resize(this->param.num_nodes); this->split_categories_segments_.resize(this->param.num_nodes);
this->split_types_.resize(this->param.num_nodes);
std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical); std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical);
} }
@ -1144,16 +1133,26 @@ void RegTree::LoadModel(Json const& in) {
} }
void RegTree::SaveModel(Json* p_out) const { void RegTree::SaveModel(Json* p_out) const {
auto& out = *p_out;
// basic properties
out["tree_param"] = ToJson(param);
// categorical splits
this->SaveCategoricalSplit(p_out);
// multi-target
if (this->IsMultiTarget()) {
CHECK_GT(param.size_leaf_vector, 1);
this->GetMultiTargetTree()->SaveModel(p_out);
return;
}
/* Here we are treating leaf node and internal node equally. Some information like /* Here we are treating leaf node and internal node equally. Some information like
* child node id doesn't make sense for leaf node but we will have to save them to * child node id doesn't make sense for leaf node but we will have to save them to
* avoid creating a huge map. One difficulty is XGBoost has deleted node created by * avoid creating a huge map. One difficulty is XGBoost has deleted node created by
* pruner, and this pruner can be used inside another updater so leaf are not necessary * pruner, and this pruner can be used inside another updater so leaf are not necessary
* at the end of node array. * at the end of node array.
*/ */
auto& out = *p_out;
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size())); CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size())); CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
out["tree_param"] = ToJson(param);
CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes)); CHECK_EQ(get<String>(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes));
auto n_nodes = param.num_nodes; auto n_nodes = param.num_nodes;
@ -1167,12 +1166,12 @@ void RegTree::SaveModel(Json* p_out) const {
I32Array rights(n_nodes); I32Array rights(n_nodes);
I32Array parents(n_nodes); I32Array parents(n_nodes);
F32Array conds(n_nodes); F32Array conds(n_nodes);
U8Array default_left(n_nodes); U8Array default_left(n_nodes);
U8Array split_type(n_nodes);
CHECK_EQ(this->split_types_.size(), param.num_nodes); CHECK_EQ(this->split_types_.size(), param.num_nodes);
namespace tf = tree_field;
auto save_tree = [&](auto* p_indices_array) { auto save_tree = [&](auto* p_indices_array) {
auto& indices_array = *p_indices_array; auto& indices_array = *p_indices_array;
for (bst_node_t i = 0; i < n_nodes; ++i) { for (bst_node_t i = 0; i < n_nodes; ++i) {
@ -1188,33 +1187,28 @@ void RegTree::SaveModel(Json* p_out) const {
indices_array.Set(i, n.SplitIndex()); indices_array.Set(i, n.SplitIndex());
conds.Set(i, n.SplitCond()); conds.Set(i, n.SplitCond());
default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft())); default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft()));
split_type.Set(i, static_cast<uint8_t>(this->NodeSplitType(i)));
} }
}; };
if (this->param.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) { if (this->param.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
I64Array indices_64(n_nodes); I64Array indices_64(n_nodes);
save_tree(&indices_64); save_tree(&indices_64);
out["split_indices"] = std::move(indices_64); out[tf::kSplitIdx] = std::move(indices_64);
} else { } else {
I32Array indices_32(n_nodes); I32Array indices_32(n_nodes);
save_tree(&indices_32); save_tree(&indices_32);
out["split_indices"] = std::move(indices_32); out[tf::kSplitIdx] = std::move(indices_32);
} }
this->SaveCategoricalSplit(&out); out[tf::kLossChg] = std::move(loss_changes);
out[tf::kSumHess] = std::move(sum_hessian);
out[tf::kBaseWeight] = std::move(base_weights);
out["split_type"] = std::move(split_type); out[tf::kLeft] = std::move(lefts);
out["loss_changes"] = std::move(loss_changes); out[tf::kRight] = std::move(rights);
out["sum_hessian"] = std::move(sum_hessian); out[tf::kParent] = std::move(parents);
out["base_weights"] = std::move(base_weights);
out["left_children"] = std::move(lefts); out[tf::kSplitCond] = std::move(conds);
out["right_children"] = std::move(rights); out[tf::kDftLeft] = std::move(default_left);
out["parents"] = std::move(parents);
out["split_conditions"] = std::move(conds);
out["default_left"] = std::move(default_left);
} }
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat, void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,

View File

@ -445,7 +445,7 @@ struct GPUHistMakerDevice {
dh::caching_device_vector<FeatureType> d_split_types; dh::caching_device_vector<FeatureType> d_split_types;
dh::caching_device_vector<uint32_t> d_categories; dh::caching_device_vector<uint32_t> d_categories;
dh::caching_device_vector<RegTree::Segment> d_categories_segments; dh::caching_device_vector<RegTree::CategoricalSplitMatrix::Segment> d_categories_segments;
if (!categories.empty()) { if (!categories.empty()) {
dh::CopyToD(h_split_types, &d_split_types); dh::CopyToD(h_split_types, &d_split_types);
@ -458,11 +458,10 @@ struct GPUHistMakerDevice {
p_out_position); p_out_position);
} }
void FinalisePositionInPage(EllpackPageImpl const *page, void FinalisePositionInPage(
const common::Span<RegTree::Node> d_nodes, EllpackPageImpl const* page, const common::Span<RegTree::Node> d_nodes,
common::Span<FeatureType const> d_feature_types, common::Span<FeatureType const> d_feature_types, common::Span<uint32_t const> categories,
common::Span<uint32_t const> categories, common::Span<RegTree::CategoricalSplitMatrix::Segment> categories_segments,
common::Span<RegTree::Segment> categories_segments,
HostDeviceVector<bst_node_t>* p_out_position) { HostDeviceVector<bst_node_t>* p_out_position) {
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
auto d_gpair = this->gpair; auto d_gpair = this->gpair;

View File

@ -0,0 +1,48 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // for Context
#include <xgboost/multi_target_tree_model.h>
#include <xgboost/tree_model.h> // for RegTree
namespace xgboost {
TEST(MultiTargetTree, JsonIO) {
bst_target_t n_targets{3};
bst_feature_t n_features{4};
RegTree tree{n_targets, n_features};
ASSERT_TRUE(tree.IsMultiTarget());
linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, Context::kCpuId};
linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, Context::kCpuId};
linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, Context::kCpuId};
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
left_weight.HostView(), right_weight.HostView());
ASSERT_EQ(tree.param.num_nodes, 3);
ASSERT_EQ(tree.param.size_leaf_vector, 3);
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
ASSERT_EQ(tree.Size(), 3);
Json jtree{Object{}};
tree.SaveModel(&jtree);
auto check_jtree = [](Json jtree, RegTree const& tree) {
ASSERT_EQ(get<String const>(jtree["tree_param"]["num_nodes"]),
std::to_string(tree.param.num_nodes));
ASSERT_EQ(get<F32Array const>(jtree["base_weights"]).size(),
tree.param.num_nodes * tree.param.size_leaf_vector);
ASSERT_EQ(get<I32Array const>(jtree["parents"]).size(), tree.param.num_nodes);
ASSERT_EQ(get<I32Array const>(jtree["left_children"]).size(), tree.param.num_nodes);
ASSERT_EQ(get<I32Array const>(jtree["right_children"]).size(), tree.param.num_nodes);
};
check_jtree(jtree, tree);
RegTree loaded;
loaded.LoadModel(jtree);
ASSERT_TRUE(loaded.IsMultiTarget());
ASSERT_EQ(loaded.param.num_nodes, 3);
Json jtree1{Object{}};
loaded.SaveModel(&jtree1);
check_jtree(jtree1, tree);
}
} // namespace xgboost

View File

@ -477,7 +477,7 @@ TEST(Tree, JsonIO) {
auto tparam = j_tree["tree_param"]; auto tparam = j_tree["tree_param"];
ASSERT_EQ(get<String>(tparam["num_feature"]), "0"); ASSERT_EQ(get<String>(tparam["num_feature"]), "0");
ASSERT_EQ(get<String>(tparam["num_nodes"]), "3"); ASSERT_EQ(get<String>(tparam["num_nodes"]), "3");
ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0"); ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "1");
ASSERT_EQ(get<I32Array const>(j_tree["left_children"]).size(), 3ul); ASSERT_EQ(get<I32Array const>(j_tree["left_children"]).size(), 3ul);
ASSERT_EQ(get<I32Array const>(j_tree["right_children"]).size(), 3ul); ASSERT_EQ(get<I32Array const>(j_tree["right_children"]).size(), 3ul);