merge latest changes
This commit is contained in:
@@ -110,11 +110,11 @@ using bst_bin_t = int32_t; // NOLINT
|
||||
*/
|
||||
using bst_row_t = std::size_t; // NOLINT
|
||||
/*! \brief Type for tree node index. */
|
||||
using bst_node_t = int32_t; // NOLINT
|
||||
using bst_node_t = std::int32_t; // NOLINT
|
||||
/*! \brief Type for ranking group index. */
|
||||
using bst_group_t = uint32_t; // NOLINT
|
||||
/*! \brief Type for indexing target variables. */
|
||||
using bst_target_t = std::size_t; // NOLINT
|
||||
using bst_group_t = std::uint32_t; // NOLINT
|
||||
/*! \brief Type for indexing into output targets. */
|
||||
using bst_target_t = std::uint32_t; // NOLINT
|
||||
|
||||
namespace detail {
|
||||
/*! \brief Implementation of gradient statistics pair. Template specialisation
|
||||
|
||||
@@ -8,29 +8,33 @@
|
||||
#ifndef XGBOOST_LEARNER_H_
|
||||
#define XGBOOST_LEARNER_H_
|
||||
|
||||
#include <dmlc/io.h> // Serializable
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/context.h> // Context
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/linalg.h> // Tensor
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/task.h>
|
||||
#include <dmlc/io.h> // for Serializable
|
||||
#include <xgboost/base.h> // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/linalg.h> // for Tensor, TensorView
|
||||
#include <xgboost/metric.h> // for Metric
|
||||
#include <xgboost/model.h> // for Configurable, Model
|
||||
#include <xgboost/span.h> // for Span
|
||||
#include <xgboost/task.h> // for ObjInfo
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm> // for max
|
||||
#include <cstdint> // for int32_t, uint32_t, uint8_t
|
||||
#include <map> // for map
|
||||
#include <memory> // for shared_ptr, unique_ptr
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class FeatureMap;
|
||||
class Metric;
|
||||
class GradientBooster;
|
||||
class ObjFunction;
|
||||
class DMatrix;
|
||||
class Json;
|
||||
struct XGBAPIThreadLocalEntry;
|
||||
template <typename T>
|
||||
class HostDeviceVector;
|
||||
|
||||
enum class PredictionType : std::uint8_t { // NOLINT
|
||||
kValue = 0,
|
||||
@@ -143,7 +147,10 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
||||
* \brief Get number of boosted rounds from gradient booster.
|
||||
*/
|
||||
virtual int32_t BoostedRounds() const = 0;
|
||||
virtual uint32_t Groups() const = 0;
|
||||
/**
|
||||
* \brief Get the number of output groups from the model.
|
||||
*/
|
||||
virtual std::uint32_t Groups() const = 0;
|
||||
|
||||
void LoadModel(Json const& in) override = 0;
|
||||
void SaveModel(Json* out) const override = 0;
|
||||
@@ -275,8 +282,16 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
||||
|
||||
struct LearnerModelParamLegacy;
|
||||
|
||||
/*
|
||||
* \brief Basic Model Parameters, used to describe the booster.
|
||||
/**
|
||||
* \brief Strategy for building multi-target models.
|
||||
*/
|
||||
enum class MultiStrategy : std::int32_t {
|
||||
kComposite = 0,
|
||||
kMonolithic = 1,
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Basic model parameters, used to describe the booster.
|
||||
*/
|
||||
struct LearnerModelParam {
|
||||
private:
|
||||
@@ -287,30 +302,51 @@ struct LearnerModelParam {
|
||||
linalg::Tensor<float, 1> base_score_;
|
||||
|
||||
public:
|
||||
/* \brief number of features */
|
||||
uint32_t num_feature { 0 };
|
||||
/* \brief number of classes, if it is multi-class classification */
|
||||
uint32_t num_output_group { 0 };
|
||||
/* \brief Current task, determined by objective. */
|
||||
/**
|
||||
* \brief The number of features.
|
||||
*/
|
||||
bst_feature_t num_feature{0};
|
||||
/**
|
||||
* \brief The number of classes or targets.
|
||||
*/
|
||||
std::uint32_t num_output_group{0};
|
||||
/**
|
||||
* \brief Current task, determined by objective.
|
||||
*/
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
/**
|
||||
* \brief Strategy for building multi-target models.
|
||||
*/
|
||||
MultiStrategy multi_strategy{MultiStrategy::kComposite};
|
||||
|
||||
LearnerModelParam() = default;
|
||||
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
|
||||
// this one as an immutable copy.
|
||||
LearnerModelParam(Context const* ctx, LearnerModelParamLegacy const& user_param,
|
||||
linalg::Tensor<float, 1> base_margin, ObjInfo t);
|
||||
LearnerModelParam(LearnerModelParamLegacy const& user_param, ObjInfo t);
|
||||
LearnerModelParam(bst_feature_t n_features, linalg::Tensor<float, 1> base_margin,
|
||||
uint32_t n_groups)
|
||||
: base_score_{std::move(base_margin)}, num_feature{n_features}, num_output_group{n_groups} {}
|
||||
linalg::Tensor<float, 1> base_margin, ObjInfo t, MultiStrategy multi_strategy);
|
||||
LearnerModelParam(LearnerModelParamLegacy const& user_param, ObjInfo t,
|
||||
MultiStrategy multi_strategy);
|
||||
LearnerModelParam(bst_feature_t n_features, linalg::Tensor<float, 1> base_score,
|
||||
std::uint32_t n_groups, bst_target_t n_targets, MultiStrategy multi_strategy)
|
||||
: base_score_{std::move(base_score)},
|
||||
num_feature{n_features},
|
||||
num_output_group{std::max(n_groups, n_targets)},
|
||||
multi_strategy{multi_strategy} {}
|
||||
|
||||
linalg::TensorView<float const, 1> BaseScore(Context const* ctx) const;
|
||||
linalg::TensorView<float const, 1> BaseScore(int32_t device) const;
|
||||
[[nodiscard]] linalg::TensorView<float const, 1> BaseScore(std::int32_t device) const;
|
||||
|
||||
void Copy(LearnerModelParam const& that);
|
||||
[[nodiscard]] bool IsVectorLeaf() const noexcept {
|
||||
return multi_strategy == MultiStrategy::kMonolithic;
|
||||
}
|
||||
[[nodiscard]] bst_target_t OutputLength() const noexcept { return this->num_output_group; }
|
||||
[[nodiscard]] bst_target_t LeafLength() const noexcept {
|
||||
return this->IsVectorLeaf() ? this->OutputLength() : 1;
|
||||
}
|
||||
|
||||
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
|
||||
bool Initialized() const { return num_feature != 0 && num_output_group != 0; }
|
||||
[[nodiscard]] bool Initialized() const { return num_feature != 0 && num_output_group != 0; }
|
||||
};
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
96
include/xgboost/multi_target_tree_model.h
Normal file
96
include/xgboost/multi_target_tree_model.h
Normal file
@@ -0,0 +1,96 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*
|
||||
* \brief Core data structure for multi-target trees.
|
||||
*/
|
||||
#ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
|
||||
#define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
|
||||
#include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/linalg.h> // for VectorView
|
||||
#include <xgboost/model.h> // for Model
|
||||
#include <xgboost/span.h> // for Span
|
||||
|
||||
#include <cinttypes> // for uint8_t
|
||||
#include <cstddef> // for size_t
|
||||
#include <vector> // for vector
|
||||
|
||||
namespace xgboost {
|
||||
struct TreeParam;
|
||||
/**
|
||||
* \brief Tree structure for multi-target model.
|
||||
*/
|
||||
class MultiTargetTree : public Model {
|
||||
public:
|
||||
static bst_node_t constexpr InvalidNodeId() { return -1; }
|
||||
|
||||
private:
|
||||
TreeParam const* param_;
|
||||
std::vector<bst_node_t> left_;
|
||||
std::vector<bst_node_t> right_;
|
||||
std::vector<bst_node_t> parent_;
|
||||
std::vector<bst_feature_t> split_index_;
|
||||
std::vector<std::uint8_t> default_left_;
|
||||
std::vector<float> split_conds_;
|
||||
std::vector<float> weights_;
|
||||
|
||||
[[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
|
||||
auto beg = nidx * this->NumTarget();
|
||||
auto v = common::Span<float const>{weights_}.subspan(beg, this->NumTarget());
|
||||
return linalg::MakeTensorView(Context::kCpuId, v, v.size());
|
||||
}
|
||||
[[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
|
||||
auto beg = nidx * this->NumTarget();
|
||||
auto v = common::Span<float>{weights_}.subspan(beg, this->NumTarget());
|
||||
return linalg::MakeTensorView(Context::kCpuId, v, v.size());
|
||||
}
|
||||
|
||||
public:
|
||||
explicit MultiTargetTree(TreeParam const* param);
|
||||
/**
|
||||
* \brief Set the weight for a leaf.
|
||||
*/
|
||||
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight);
|
||||
/**
|
||||
* \brief Expand a leaf into split node.
|
||||
*/
|
||||
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
|
||||
linalg::VectorView<float const> base_weight,
|
||||
linalg::VectorView<float const> left_weight,
|
||||
linalg::VectorView<float const> right_weight);
|
||||
|
||||
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); }
|
||||
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); }
|
||||
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); }
|
||||
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); }
|
||||
|
||||
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; }
|
||||
[[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; }
|
||||
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; }
|
||||
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
|
||||
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
|
||||
}
|
||||
|
||||
[[nodiscard]] bst_target_t NumTarget() const;
|
||||
|
||||
[[nodiscard]] std::size_t Size() const;
|
||||
|
||||
[[nodiscard]] bst_node_t Depth(bst_node_t nidx) const {
|
||||
bst_node_t depth{0};
|
||||
while (Parent(nidx) != InvalidNodeId()) {
|
||||
++depth;
|
||||
nidx = Parent(nidx);
|
||||
}
|
||||
return depth;
|
||||
}
|
||||
|
||||
[[nodiscard]] linalg::VectorView<float const> LeafValue(bst_node_t nidx) const {
|
||||
CHECK(IsLeaf(nidx));
|
||||
return this->NodeWeight(nidx);
|
||||
}
|
||||
|
||||
void LoadModel(Json const& in) override;
|
||||
void SaveModel(Json* out) const override;
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2022 by Contributors
|
||||
/**
|
||||
* Copyright 2014-2023 by Contributors
|
||||
* \file tree_model.h
|
||||
* \brief model structure for tree
|
||||
* \author Tianqi Chen
|
||||
@@ -9,60 +9,62 @@
|
||||
|
||||
#include <dmlc/io.h>
|
||||
#include <dmlc/parameter.h>
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/linalg.h> // for VectorView
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/multi_target_tree_model.h> // for MultiTargetTree
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <memory> // for make_unique
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
struct PathElement; // forward declaration
|
||||
|
||||
class Json;
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
#define XGBOOST_NODISCARD
|
||||
#else
|
||||
#define XGBOOST_NODISCARD [[nodiscard]]
|
||||
#endif
|
||||
// FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
|
||||
// not be configured by users.
|
||||
/*! \brief meta parameters of the tree */
|
||||
struct TreeParam : public dmlc::Parameter<TreeParam> {
|
||||
/*! \brief (Deprecated) number of start root */
|
||||
int deprecated_num_roots;
|
||||
int deprecated_num_roots{1};
|
||||
/*! \brief total number of nodes */
|
||||
int num_nodes;
|
||||
int num_nodes{1};
|
||||
/*!\brief number of deleted nodes */
|
||||
int num_deleted;
|
||||
int num_deleted{0};
|
||||
/*! \brief maximum depth, this is a statistics of the tree */
|
||||
int deprecated_max_depth;
|
||||
int deprecated_max_depth{0};
|
||||
/*! \brief number of features used for tree construction */
|
||||
bst_feature_t num_feature;
|
||||
bst_feature_t num_feature{0};
|
||||
/*!
|
||||
* \brief leaf vector size, used for vector tree
|
||||
* used to store more than one dimensional information in tree
|
||||
*/
|
||||
int size_leaf_vector;
|
||||
bst_target_t size_leaf_vector{1};
|
||||
/*! \brief reserved part, make sure alignment works for 64bit */
|
||||
int reserved[31];
|
||||
/*! \brief constructor */
|
||||
TreeParam() {
|
||||
// assert compact alignment
|
||||
static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
|
||||
"TreeParam: 64 bit align");
|
||||
std::memset(this, 0, sizeof(TreeParam));
|
||||
num_nodes = 1;
|
||||
deprecated_num_roots = 1;
|
||||
static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), "TreeParam: 64 bit align");
|
||||
std::memset(reserved, 0, sizeof(reserved));
|
||||
}
|
||||
|
||||
// Swap byte order for all fields. Useful for transporting models between machines with different
|
||||
// endianness (big endian vs little endian)
|
||||
inline TreeParam ByteSwap() const {
|
||||
XGBOOST_NODISCARD TreeParam ByteSwap() const {
|
||||
TreeParam x = *this;
|
||||
dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
|
||||
dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
|
||||
@@ -80,17 +82,18 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
|
||||
// other arguments are set by the algorithm.
|
||||
DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
|
||||
DMLC_DECLARE_FIELD(num_feature)
|
||||
.set_default(0)
|
||||
.describe("Number of features used in tree construction.");
|
||||
DMLC_DECLARE_FIELD(num_deleted);
|
||||
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
|
||||
DMLC_DECLARE_FIELD(num_deleted).set_default(0);
|
||||
DMLC_DECLARE_FIELD(size_leaf_vector)
|
||||
.set_lower_bound(0)
|
||||
.set_default(1)
|
||||
.describe("Size of leaf vector, reserved for vector tree");
|
||||
}
|
||||
|
||||
bool operator==(const TreeParam& b) const {
|
||||
return num_nodes == b.num_nodes &&
|
||||
num_deleted == b.num_deleted &&
|
||||
num_feature == b.num_feature &&
|
||||
size_leaf_vector == b.size_leaf_vector;
|
||||
return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
|
||||
num_feature == b.num_feature && size_leaf_vector == b.size_leaf_vector;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -114,7 +117,7 @@ struct RTreeNodeStat {
|
||||
}
|
||||
// Swap byte order for all fields. Useful for transporting models between machines with different
|
||||
// endianness (big endian vs little endian)
|
||||
inline RTreeNodeStat ByteSwap() const {
|
||||
XGBOOST_NODISCARD RTreeNodeStat ByteSwap() const {
|
||||
RTreeNodeStat x = *this;
|
||||
dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
|
||||
dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
|
||||
@@ -124,16 +127,45 @@ struct RTreeNodeStat {
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
/**
|
||||
* \brief Helper for defining copyable data structure that contains unique pointers.
|
||||
*/
|
||||
template <typename T>
|
||||
class CopyUniquePtr {
|
||||
std::unique_ptr<T> ptr_{nullptr};
|
||||
|
||||
public:
|
||||
CopyUniquePtr() = default;
|
||||
CopyUniquePtr(CopyUniquePtr const& that) {
|
||||
ptr_.reset(nullptr);
|
||||
if (that.ptr_) {
|
||||
ptr_ = std::make_unique<T>(*that);
|
||||
}
|
||||
}
|
||||
T* get() const noexcept { return ptr_.get(); } // NOLINT
|
||||
|
||||
T& operator*() { return *ptr_; }
|
||||
T* operator->() noexcept { return this->get(); }
|
||||
|
||||
T const& operator*() const { return *ptr_; }
|
||||
T const* operator->() const noexcept { return this->get(); }
|
||||
|
||||
explicit operator bool() const { return static_cast<bool>(ptr_); }
|
||||
bool operator!() const { return !ptr_; }
|
||||
void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief define regression tree to be the most common tree model.
|
||||
*
|
||||
* This is the data structure used in xgboost's major tree models.
|
||||
*/
|
||||
class RegTree : public Model {
|
||||
public:
|
||||
using SplitCondT = bst_float;
|
||||
static constexpr bst_node_t kInvalidNodeId {-1};
|
||||
static constexpr bst_node_t kInvalidNodeId{MultiTargetTree::InvalidNodeId()};
|
||||
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
|
||||
static constexpr bst_node_t kRoot { 0 };
|
||||
static constexpr bst_node_t kRoot{0};
|
||||
|
||||
/*! \brief tree node */
|
||||
class Node {
|
||||
@@ -151,51 +183,51 @@ class RegTree : public Model {
|
||||
}
|
||||
|
||||
/*! \brief index of left child */
|
||||
XGBOOST_DEVICE int LeftChild() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD int LeftChild() const {
|
||||
return this->cleft_;
|
||||
}
|
||||
/*! \brief index of right child */
|
||||
XGBOOST_DEVICE int RightChild() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD int RightChild() const {
|
||||
return this->cright_;
|
||||
}
|
||||
/*! \brief index of default child when feature is missing */
|
||||
XGBOOST_DEVICE int DefaultChild() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD int DefaultChild() const {
|
||||
return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
|
||||
}
|
||||
/*! \brief feature index of split condition */
|
||||
XGBOOST_DEVICE unsigned SplitIndex() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD unsigned SplitIndex() const {
|
||||
return sindex_ & ((1U << 31) - 1U);
|
||||
}
|
||||
/*! \brief when feature is unknown, whether goes to left child */
|
||||
XGBOOST_DEVICE bool DefaultLeft() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool DefaultLeft() const {
|
||||
return (sindex_ >> 31) != 0;
|
||||
}
|
||||
/*! \brief whether current node is leaf node */
|
||||
XGBOOST_DEVICE bool IsLeaf() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool IsLeaf() const {
|
||||
return cleft_ == kInvalidNodeId;
|
||||
}
|
||||
/*! \return get leaf value of leaf node */
|
||||
XGBOOST_DEVICE bst_float LeafValue() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD float LeafValue() const {
|
||||
return (this->info_).leaf_value;
|
||||
}
|
||||
/*! \return get split condition of the node */
|
||||
XGBOOST_DEVICE SplitCondT SplitCond() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD SplitCondT SplitCond() const {
|
||||
return (this->info_).split_cond;
|
||||
}
|
||||
/*! \brief get parent of the node */
|
||||
XGBOOST_DEVICE int Parent() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD int Parent() const {
|
||||
return parent_ & ((1U << 31) - 1);
|
||||
}
|
||||
/*! \brief whether current node is left child */
|
||||
XGBOOST_DEVICE bool IsLeftChild() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool IsLeftChild() const {
|
||||
return (parent_ & (1U << 31)) != 0;
|
||||
}
|
||||
/*! \brief whether this node is deleted */
|
||||
XGBOOST_DEVICE bool IsDeleted() const {
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool IsDeleted() const {
|
||||
return sindex_ == kDeletedNodeMarker;
|
||||
}
|
||||
/*! \brief whether current node is root */
|
||||
XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool IsRoot() const { return parent_ == kInvalidNodeId; }
|
||||
/*!
|
||||
* \brief set the left child
|
||||
* \param nid node id to right child
|
||||
@@ -252,7 +284,7 @@ class RegTree : public Model {
|
||||
info_.leaf_value == b.info_.leaf_value;
|
||||
}
|
||||
|
||||
inline Node ByteSwap() const {
|
||||
XGBOOST_NODISCARD Node ByteSwap() const {
|
||||
Node x = *this;
|
||||
dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
|
||||
dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
|
||||
@@ -312,19 +344,28 @@ class RegTree : public Model {
|
||||
|
||||
/*! \brief model parameter */
|
||||
TreeParam param;
|
||||
/*! \brief constructor */
|
||||
RegTree() {
|
||||
param.num_nodes = 1;
|
||||
param.num_deleted = 0;
|
||||
param.Init(Args{});
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param.num_nodes);
|
||||
for (int i = 0; i < param.num_nodes; i ++) {
|
||||
for (int i = 0; i < param.num_nodes; i++) {
|
||||
nodes_[i].SetLeaf(0.0f);
|
||||
nodes_[i].SetParent(kInvalidNodeId);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* \brief Constructor that initializes the tree model with shape.
|
||||
*/
|
||||
explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
|
||||
param.num_feature = n_features;
|
||||
param.size_leaf_vector = n_targets;
|
||||
if (n_targets > 1) {
|
||||
this->p_mt_tree_.reset(new MultiTargetTree{¶m});
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief get node given nid */
|
||||
Node& operator[](int nid) {
|
||||
return nodes_[nid];
|
||||
@@ -335,17 +376,17 @@ class RegTree : public Model {
|
||||
}
|
||||
|
||||
/*! \brief get const reference to nodes */
|
||||
const std::vector<Node>& GetNodes() const { return nodes_; }
|
||||
XGBOOST_NODISCARD const std::vector<Node>& GetNodes() const { return nodes_; }
|
||||
|
||||
/*! \brief get const reference to stats */
|
||||
const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
|
||||
XGBOOST_NODISCARD const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
|
||||
|
||||
/*! \brief get node statistics given nid */
|
||||
RTreeNodeStat& Stat(int nid) {
|
||||
return stats_[nid];
|
||||
}
|
||||
/*! \brief get node statistics given nid */
|
||||
const RTreeNodeStat& Stat(int nid) const {
|
||||
XGBOOST_NODISCARD const RTreeNodeStat& Stat(int nid) const {
|
||||
return stats_[nid];
|
||||
}
|
||||
|
||||
@@ -398,7 +439,7 @@ class RegTree : public Model {
|
||||
*
|
||||
* \param b The other tree.
|
||||
*/
|
||||
bool Equal(const RegTree& b) const;
|
||||
XGBOOST_NODISCARD bool Equal(const RegTree& b) const;
|
||||
|
||||
/**
|
||||
* \brief Expands a leaf node into two additional leaf nodes.
|
||||
@@ -424,6 +465,11 @@ class RegTree : public Model {
|
||||
float right_sum,
|
||||
bst_node_t leaf_right_child = kInvalidNodeId);
|
||||
|
||||
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
|
||||
linalg::VectorView<float const> base_weight,
|
||||
linalg::VectorView<float const> left_weight,
|
||||
linalg::VectorView<float const> right_weight);
|
||||
|
||||
/**
|
||||
* \brief Expands a leaf node with categories
|
||||
*
|
||||
@@ -445,15 +491,27 @@ class RegTree : public Model {
|
||||
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
|
||||
float left_sum, float right_sum);
|
||||
|
||||
bool HasCategoricalSplit() const {
|
||||
XGBOOST_NODISCARD bool HasCategoricalSplit() const {
|
||||
return !split_categories_.empty();
|
||||
}
|
||||
/**
|
||||
* \brief Whether this is a multi-target tree.
|
||||
*/
|
||||
XGBOOST_NODISCARD bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
|
||||
XGBOOST_NODISCARD bst_target_t NumTargets() const { return param.size_leaf_vector; }
|
||||
XGBOOST_NODISCARD auto GetMultiTargetTree() const {
|
||||
CHECK(IsMultiTarget());
|
||||
return p_mt_tree_.get();
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief get current depth
|
||||
* \param nid node id
|
||||
*/
|
||||
int GetDepth(int nid) const {
|
||||
XGBOOST_NODISCARD std::int32_t GetDepth(bst_node_t nid) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->Depth(nid);
|
||||
}
|
||||
int depth = 0;
|
||||
while (!nodes_[nid].IsRoot()) {
|
||||
++depth;
|
||||
@@ -461,12 +519,16 @@ class RegTree : public Model {
|
||||
}
|
||||
return depth;
|
||||
}
|
||||
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
|
||||
CHECK(IsMultiTarget());
|
||||
return this->p_mt_tree_->SetLeaf(nidx, weight);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief get maximum depth
|
||||
* \param nid node id
|
||||
*/
|
||||
int MaxDepth(int nid) const {
|
||||
XGBOOST_NODISCARD int MaxDepth(int nid) const {
|
||||
if (nodes_[nid].IsLeaf()) return 0;
|
||||
return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
|
||||
MaxDepth(nodes_[nid].RightChild())+1);
|
||||
@@ -480,13 +542,13 @@ class RegTree : public Model {
|
||||
}
|
||||
|
||||
/*! \brief number of extra nodes besides the root */
|
||||
int NumExtraNodes() const {
|
||||
XGBOOST_NODISCARD int NumExtraNodes() const {
|
||||
return param.num_nodes - 1 - param.num_deleted;
|
||||
}
|
||||
|
||||
/* \brief Count number of leaves in tree. */
|
||||
bst_node_t GetNumLeaves() const;
|
||||
bst_node_t GetNumSplitNodes() const;
|
||||
XGBOOST_NODISCARD bst_node_t GetNumLeaves() const;
|
||||
XGBOOST_NODISCARD bst_node_t GetNumSplitNodes() const;
|
||||
|
||||
/*!
|
||||
* \brief dense feature vector that can be taken by RegTree
|
||||
@@ -513,20 +575,20 @@ class RegTree : public Model {
|
||||
* \brief returns the size of the feature vector
|
||||
* \return the size of the feature vector
|
||||
*/
|
||||
size_t Size() const;
|
||||
XGBOOST_NODISCARD size_t Size() const;
|
||||
/*!
|
||||
* \brief get ith value
|
||||
* \param i feature index.
|
||||
* \return the i-th feature value
|
||||
*/
|
||||
bst_float GetFvalue(size_t i) const;
|
||||
XGBOOST_NODISCARD bst_float GetFvalue(size_t i) const;
|
||||
/*!
|
||||
* \brief check whether i-th entry is missing
|
||||
* \param i feature index.
|
||||
* \return whether i-th value is missing.
|
||||
*/
|
||||
bool IsMissing(size_t i) const;
|
||||
bool HasMissing() const;
|
||||
XGBOOST_NODISCARD bool IsMissing(size_t i) const;
|
||||
XGBOOST_NODISCARD bool HasMissing() const;
|
||||
|
||||
|
||||
private:
|
||||
@@ -557,56 +619,123 @@ class RegTree : public Model {
|
||||
* \param format the format to dump the model in
|
||||
* \return the string of dumped model
|
||||
*/
|
||||
std::string DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const;
|
||||
XGBOOST_NODISCARD std::string DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||
std::string format) const;
|
||||
/*!
|
||||
* \brief Get split type for a node.
|
||||
* \param nidx Index of node.
|
||||
* \return The type of this split. For leaf node it's always kNumerical.
|
||||
*/
|
||||
FeatureType NodeSplitType(bst_node_t nidx) const {
|
||||
return split_types_.at(nidx);
|
||||
}
|
||||
XGBOOST_NODISCARD FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
|
||||
/*!
|
||||
* \brief Get split types for all nodes.
|
||||
*/
|
||||
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
|
||||
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
|
||||
XGBOOST_NODISCARD std::vector<FeatureType> const& GetSplitTypes() const {
|
||||
return split_types_;
|
||||
}
|
||||
XGBOOST_NODISCARD common::Span<uint32_t const> GetSplitCategories() const {
|
||||
return split_categories_;
|
||||
}
|
||||
/*!
|
||||
* \brief Get the bit storage for categories
|
||||
*/
|
||||
common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
|
||||
XGBOOST_NODISCARD common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
|
||||
auto node_ptr = GetCategoriesMatrix().node_ptr;
|
||||
auto categories = GetCategoriesMatrix().categories;
|
||||
auto segment = node_ptr[nidx];
|
||||
auto node_cats = categories.subspan(segment.beg, segment.size);
|
||||
return node_cats;
|
||||
}
|
||||
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
|
||||
|
||||
// The fields of split_categories_segments_[i] are set such that
|
||||
// the range split_categories_[beg:(beg+size)] stores the bitset for
|
||||
// the matching categories for the i-th node.
|
||||
struct Segment {
|
||||
size_t beg {0};
|
||||
size_t size {0};
|
||||
};
|
||||
XGBOOST_NODISCARD auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
|
||||
|
||||
/**
|
||||
* \brief CSR-like matrix for categorical splits.
|
||||
*
|
||||
* The fields of split_categories_segments_[i] are set such that the range
|
||||
* node_ptr[beg:(beg+size)] stores the bitset for the matching categories for the
|
||||
* i-th node.
|
||||
*/
|
||||
struct CategoricalSplitMatrix {
|
||||
struct Segment {
|
||||
std::size_t beg{0};
|
||||
std::size_t size{0};
|
||||
};
|
||||
common::Span<FeatureType const> split_type;
|
||||
common::Span<uint32_t const> categories;
|
||||
common::Span<Segment const> node_ptr;
|
||||
};
|
||||
|
||||
CategoricalSplitMatrix GetCategoriesMatrix() const {
|
||||
XGBOOST_NODISCARD CategoricalSplitMatrix GetCategoriesMatrix() const {
|
||||
CategoricalSplitMatrix view;
|
||||
view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
|
||||
view.categories = this->GetSplitCategories();
|
||||
view.node_ptr = common::Span<Segment const>(split_categories_segments_);
|
||||
view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
|
||||
return view;
|
||||
}
|
||||
|
||||
XGBOOST_NODISCARD bst_feature_t SplitIndex(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->SplitIndex(nidx);
|
||||
}
|
||||
return (*this)[nidx].SplitIndex();
|
||||
}
|
||||
XGBOOST_NODISCARD float SplitCond(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->SplitCond(nidx);
|
||||
}
|
||||
return (*this)[nidx].SplitCond();
|
||||
}
|
||||
XGBOOST_NODISCARD bool DefaultLeft(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->DefaultLeft(nidx);
|
||||
}
|
||||
return (*this)[nidx].DefaultLeft();
|
||||
}
|
||||
XGBOOST_NODISCARD bool IsRoot(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return nidx == kRoot;
|
||||
}
|
||||
return (*this)[nidx].IsRoot();
|
||||
}
|
||||
XGBOOST_NODISCARD bool IsLeaf(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->IsLeaf(nidx);
|
||||
}
|
||||
return (*this)[nidx].IsLeaf();
|
||||
}
|
||||
XGBOOST_NODISCARD bst_node_t Parent(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->Parent(nidx);
|
||||
}
|
||||
return (*this)[nidx].Parent();
|
||||
}
|
||||
XGBOOST_NODISCARD bst_node_t LeftChild(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->LeftChild(nidx);
|
||||
}
|
||||
return (*this)[nidx].LeftChild();
|
||||
}
|
||||
XGBOOST_NODISCARD bst_node_t RightChild(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->RightChild(nidx);
|
||||
}
|
||||
return (*this)[nidx].RightChild();
|
||||
}
|
||||
XGBOOST_NODISCARD bool IsLeftChild(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
CHECK_NE(nidx, kRoot);
|
||||
auto p = this->p_mt_tree_->Parent(nidx);
|
||||
return nidx == this->p_mt_tree_->LeftChild(p);
|
||||
}
|
||||
return (*this)[nidx].IsLeftChild();
|
||||
}
|
||||
XGBOOST_NODISCARD bst_node_t Size() const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->Size();
|
||||
}
|
||||
return this->nodes_.size();
|
||||
}
|
||||
|
||||
private:
|
||||
template <bool typed>
|
||||
void LoadCategoricalSplit(Json const& in);
|
||||
@@ -622,8 +751,9 @@ class RegTree : public Model {
|
||||
// Categories for each internal node.
|
||||
std::vector<uint32_t> split_categories_;
|
||||
// Ptr to split categories of each node.
|
||||
std::vector<Segment> split_categories_segments_;
|
||||
|
||||
std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
|
||||
// ptr to multi-target tree with vector leaf.
|
||||
CopyUniquePtr<MultiTargetTree> p_mt_tree_;
|
||||
// allocate a new node,
|
||||
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
||||
bst_node_t AllocNode() {
|
||||
@@ -703,5 +833,10 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
|
||||
inline bool RegTree::FVec::HasMissing() const {
|
||||
return has_missing_;
|
||||
}
|
||||
|
||||
// Multi-target tree not yet implemented error
|
||||
inline StringView MTNotImplemented() {
|
||||
return " support for multi-target tree is not yet implemented.";
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_MODEL_H_
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
* \file tree_updater.h
|
||||
* \brief General primitive for tree learning,
|
||||
* Updating a collection of trees given the information.
|
||||
@@ -9,19 +9,17 @@
|
||||
#define XGBOOST_TREE_UPDATER_H_
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/linalg.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/task.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/base.h> // for Args, GradientPair
|
||||
#include <xgboost/data.h> // DMatrix
|
||||
#include <xgboost/host_device_vector.h> // for HostDeviceVector
|
||||
#include <xgboost/linalg.h> // for VectorView
|
||||
#include <xgboost/model.h> // for Configurable
|
||||
#include <xgboost/span.h> // for Span
|
||||
#include <xgboost/tree_model.h> // for RegTree
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional> // for function
|
||||
#include <string> // for string
|
||||
#include <vector> // for vector
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -30,8 +28,9 @@ struct TrainParam;
|
||||
|
||||
class Json;
|
||||
struct Context;
|
||||
struct ObjInfo;
|
||||
|
||||
/*!
|
||||
/**
|
||||
* \brief interface of tree update module, that performs update of a tree.
|
||||
*/
|
||||
class TreeUpdater : public Configurable {
|
||||
@@ -53,12 +52,12 @@ class TreeUpdater : public Configurable {
|
||||
* used for modifying existing trees (like `prune`). Return true if it can modify
|
||||
* existing trees.
|
||||
*/
|
||||
virtual bool CanModifyTree() const { return false; }
|
||||
[[nodiscard]] virtual bool CanModifyTree() const { return false; }
|
||||
/*!
|
||||
* \brief Wether the out_position in `Update` is valid. This determines whether adaptive
|
||||
* tree can be used.
|
||||
*/
|
||||
virtual bool HasNodePosition() const { return false; }
|
||||
[[nodiscard]] virtual bool HasNodePosition() const { return false; }
|
||||
/**
|
||||
* \brief perform update to the tree models
|
||||
*
|
||||
@@ -91,14 +90,15 @@ class TreeUpdater : public Configurable {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual char const* Name() const = 0;
|
||||
[[nodiscard]] virtual char const* Name() const = 0;
|
||||
|
||||
/*!
|
||||
/**
|
||||
* \brief Create a tree updater given name
|
||||
* \param name Name of the tree updater.
|
||||
* \param ctx A global runtime parameter
|
||||
* \param task Infomation about the objective.
|
||||
*/
|
||||
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo task);
|
||||
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
|
||||
};
|
||||
|
||||
/*!
|
||||
@@ -106,7 +106,7 @@ class TreeUpdater : public Configurable {
|
||||
*/
|
||||
struct TreeUpdaterReg
|
||||
: public dmlc::FunctionRegEntryBase<
|
||||
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo task)>> {};
|
||||
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};
|
||||
|
||||
/*!
|
||||
* \brief Macro to register tree updater.
|
||||
|
||||
Reference in New Issue
Block a user