initial merge
This commit is contained in:
@@ -116,6 +116,18 @@ class DMatrixCache {
|
||||
* \param cache_size Maximum size of the cache.
|
||||
*/
|
||||
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
|
||||
|
||||
DMatrixCache& operator=(DMatrixCache&& that) {
|
||||
CHECK(lock_.try_lock());
|
||||
lock_.unlock();
|
||||
CHECK(that.lock_.try_lock());
|
||||
that.lock_.unlock();
|
||||
std::swap(this->container_, that.container_);
|
||||
std::swap(this->queue_, that.queue_);
|
||||
std::swap(this->max_size_, that.max_size_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Cache a new DMatrix if it's not in the cache already.
|
||||
*
|
||||
@@ -149,6 +161,26 @@ class DMatrixCache {
|
||||
}
|
||||
return container_.at(key).value;
|
||||
}
|
||||
/**
|
||||
* \brief Re-initialize the item in cache.
|
||||
*
|
||||
* Since the shared_ptr is used to hold the item, any reference that lives outside of
|
||||
* the cache can no-longer be reached from the cache.
|
||||
*
|
||||
* We use reset instead of erase to avoid walking through the whole cache for renewing
|
||||
* a single item. (the cache is FIFO, needs to maintain the order).
|
||||
*/
|
||||
template <typename... Args>
|
||||
std::shared_ptr<CacheT> ResetItem(std::shared_ptr<DMatrix> m, Args const&... args) {
|
||||
std::lock_guard<std::mutex> guard{lock_};
|
||||
CheckConsistent();
|
||||
auto key = Key{m.get(), std::this_thread::get_id()};
|
||||
auto it = container_.find(key);
|
||||
CHECK(it != container_.cend());
|
||||
it->second = {m, std::make_shared<CacheT>(args...)};
|
||||
CheckConsistent();
|
||||
return it->second.value;
|
||||
}
|
||||
/**
|
||||
* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
||||
* returning.
|
||||
|
||||
@@ -171,6 +171,15 @@ class MetaInfo {
|
||||
*/
|
||||
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
|
||||
|
||||
/**
|
||||
* @brief Synchronize the number of columns across all workers.
|
||||
*
|
||||
* Normally we just need to find the maximum number of columns across all workers, but
|
||||
* in vertical federated learning, since each worker loads its own list of columns,
|
||||
* we need to sum them.
|
||||
*/
|
||||
void SynchronizeNumberOfColumns();
|
||||
|
||||
private:
|
||||
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
|
||||
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
|
||||
@@ -325,6 +334,10 @@ class SparsePage {
|
||||
* \brief Check wether the column index is sorted.
|
||||
*/
|
||||
bool IsIndicesSorted(int32_t n_threads) const;
|
||||
/**
|
||||
* \brief Reindex the column index with an offset.
|
||||
*/
|
||||
void Reindex(uint64_t feature_offset, int32_t n_threads);
|
||||
|
||||
void SortRows(int32_t n_threads);
|
||||
|
||||
@@ -559,17 +572,18 @@ class DMatrix {
|
||||
* \brief Creates a new DMatrix from an external data adapter.
|
||||
*
|
||||
* \tparam AdapterT Type of the adapter.
|
||||
* \param [in,out] adapter View onto an external data.
|
||||
* \param missing Values to count as missing.
|
||||
* \param nthread Number of threads for construction.
|
||||
* \param cache_prefix (Optional) The cache prefix for external memory.
|
||||
* \param page_size (Optional) Size of the page.
|
||||
* \param [in,out] adapter View onto an external data.
|
||||
* \param missing Values to count as missing.
|
||||
* \param nthread Number of threads for construction.
|
||||
* \param cache_prefix (Optional) The cache prefix for external memory.
|
||||
* \param data_split_mode (Optional) Data split mode.
|
||||
*
|
||||
* \return a Created DMatrix.
|
||||
*/
|
||||
template <typename AdapterT>
|
||||
static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix = "");
|
||||
const std::string& cache_prefix = "",
|
||||
DataSplitMode data_split_mode = DataSplitMode::kRow);
|
||||
|
||||
/**
|
||||
* \brief Create a new Quantile based DMatrix used for histogram based algorithm.
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#define XGBOOST_GBM_H_
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <dmlc/any.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) by Contributors 2019-2022
|
||||
/**
|
||||
* Copyright 2019-2023, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_JSON_IO_H_
|
||||
#define XGBOOST_JSON_IO_H_
|
||||
@@ -17,44 +17,26 @@
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
namespace detail {
|
||||
// Whether char is signed is undefined, as a result we might or might not need
|
||||
// static_cast and std::to_string.
|
||||
template <typename Char, std::enable_if_t<std::is_signed<Char>::value>* = nullptr>
|
||||
std::string CharToStr(Char c) {
|
||||
static_assert(std::is_same<Char, char>::value);
|
||||
return std::string{c};
|
||||
}
|
||||
|
||||
template <typename Char, std::enable_if_t<!std::is_signed<Char>::value>* = nullptr>
|
||||
std::string CharToStr(Char c) {
|
||||
static_assert(std::is_same<Char, char>::value);
|
||||
return (c <= static_cast<char>(127) ? std::string{c} : std::to_string(c));
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/*
|
||||
/**
|
||||
* \brief A json reader, currently error checking and utf-8 is not fully supported.
|
||||
*/
|
||||
class JsonReader {
|
||||
public:
|
||||
using Char = std::int8_t;
|
||||
|
||||
protected:
|
||||
size_t constexpr static kMaxNumLength =
|
||||
std::numeric_limits<double>::max_digits10 + 1;
|
||||
size_t constexpr static kMaxNumLength = std::numeric_limits<double>::max_digits10 + 1;
|
||||
|
||||
struct SourceLocation {
|
||||
private:
|
||||
size_t pos_ { 0 }; // current position in raw_str_
|
||||
std::size_t pos_{0}; // current position in raw_str_
|
||||
|
||||
public:
|
||||
SourceLocation() = default;
|
||||
size_t Pos() const { return pos_; }
|
||||
size_t Pos() const { return pos_; }
|
||||
|
||||
void Forward() {
|
||||
pos_++;
|
||||
}
|
||||
void Forward(uint32_t n) {
|
||||
pos_ += n;
|
||||
}
|
||||
void Forward() { pos_++; }
|
||||
void Forward(uint32_t n) { pos_ += n; }
|
||||
} cursor_;
|
||||
|
||||
StringView raw_str_;
|
||||
@@ -62,7 +44,7 @@ class JsonReader {
|
||||
protected:
|
||||
void SkipSpaces();
|
||||
|
||||
char GetNextChar() {
|
||||
Char GetNextChar() {
|
||||
if (XGBOOST_EXPECT((cursor_.Pos() == raw_str_.size()), false)) {
|
||||
return -1;
|
||||
}
|
||||
@@ -71,24 +53,24 @@ class JsonReader {
|
||||
return ch;
|
||||
}
|
||||
|
||||
char PeekNextChar() {
|
||||
Char PeekNextChar() {
|
||||
if (cursor_.Pos() == raw_str_.size()) {
|
||||
return -1;
|
||||
}
|
||||
char ch = raw_str_[cursor_.Pos()];
|
||||
Char ch = raw_str_[cursor_.Pos()];
|
||||
return ch;
|
||||
}
|
||||
|
||||
/* \brief Skip spaces and consume next character. */
|
||||
char GetNextNonSpaceChar() {
|
||||
Char GetNextNonSpaceChar() {
|
||||
SkipSpaces();
|
||||
return GetNextChar();
|
||||
}
|
||||
/* \brief Consume next character without first skipping empty space, throw when the next
|
||||
* character is not the expected one.
|
||||
*/
|
||||
char GetConsecutiveChar(char expected_char) {
|
||||
char result = GetNextChar();
|
||||
Char GetConsecutiveChar(char expected_char) {
|
||||
Char result = GetNextChar();
|
||||
if (XGBOOST_EXPECT(result != expected_char, false)) { Expect(expected_char, result); }
|
||||
return result;
|
||||
}
|
||||
@@ -96,7 +78,7 @@ class JsonReader {
|
||||
void Error(std::string msg) const;
|
||||
|
||||
// Report expected character
|
||||
void Expect(char c, char got) {
|
||||
void Expect(Char c, Char got) {
|
||||
std::string msg = "Expecting: \"";
|
||||
msg += c;
|
||||
msg += "\", got: \"";
|
||||
@@ -105,7 +87,7 @@ class JsonReader {
|
||||
} else if (got == 0) {
|
||||
msg += "\\0\"";
|
||||
} else {
|
||||
msg += detail::CharToStr(got) + " \"";
|
||||
msg += std::to_string(got) + " \"";
|
||||
}
|
||||
Error(msg);
|
||||
}
|
||||
|
||||
@@ -286,8 +286,8 @@ struct LearnerModelParamLegacy;
|
||||
* \brief Strategy for building multi-target models.
|
||||
*/
|
||||
enum class MultiStrategy : std::int32_t {
|
||||
kComposite = 0,
|
||||
kMonolithic = 1,
|
||||
kOneOutputPerTree = 0,
|
||||
kMultiOutputTree = 1,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -317,7 +317,7 @@ struct LearnerModelParam {
|
||||
/**
|
||||
* \brief Strategy for building multi-target models.
|
||||
*/
|
||||
MultiStrategy multi_strategy{MultiStrategy::kComposite};
|
||||
MultiStrategy multi_strategy{MultiStrategy::kOneOutputPerTree};
|
||||
|
||||
LearnerModelParam() = default;
|
||||
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
|
||||
@@ -338,7 +338,7 @@ struct LearnerModelParam {
|
||||
|
||||
void Copy(LearnerModelParam const& that);
|
||||
[[nodiscard]] bool IsVectorLeaf() const noexcept {
|
||||
return multi_strategy == MultiStrategy::kMonolithic;
|
||||
return multi_strategy == MultiStrategy::kMultiOutputTree;
|
||||
}
|
||||
[[nodiscard]] bst_target_t OutputLength() const noexcept { return this->num_output_group; }
|
||||
[[nodiscard]] bst_target_t LeafLength() const noexcept {
|
||||
|
||||
@@ -30,11 +30,11 @@
|
||||
|
||||
// decouple it from xgboost.
|
||||
#ifndef LINALG_HD
|
||||
#if defined(__CUDA__) || defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__CUDA__) || defined(__NVCC__)
|
||||
#define LINALG_HD __host__ __device__
|
||||
#else
|
||||
#define LINALG_HD
|
||||
#endif // defined (__CUDA__) || defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__)
|
||||
#endif // defined (__CUDA__) || defined(__NVCC__)
|
||||
#endif // LINALG_HD
|
||||
|
||||
namespace xgboost::linalg {
|
||||
@@ -118,9 +118,9 @@ using IndexToTag = std::conditional_t<std::is_integral<RemoveCRType<S>>::value,
|
||||
|
||||
template <int32_t n, typename Fn>
|
||||
LINALG_HD constexpr auto UnrollLoop(Fn fn) {
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined __CUDA_ARCH__
|
||||
#pragma unroll n
|
||||
#endif // defined __CUDA_ARCH__ || defined(__HIP_PLATFORM_AMD__)
|
||||
#endif // defined __CUDA_ARCH__
|
||||
for (int32_t i = 0; i < n; ++i) {
|
||||
fn(i);
|
||||
}
|
||||
@@ -136,7 +136,7 @@ int32_t NativePopc(T v) {
|
||||
inline LINALG_HD int Popc(uint32_t v) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return __popc(v);
|
||||
#elif defined(__GNUC__) || defined(__clang__) || defined(__HIP_PLATFORM_AMD__)
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
return __builtin_popcount(v);
|
||||
#elif defined(_MSC_VER)
|
||||
return __popcnt(v);
|
||||
@@ -148,7 +148,7 @@ inline LINALG_HD int Popc(uint32_t v) {
|
||||
inline LINALG_HD int Popc(uint64_t v) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return __popcll(v);
|
||||
#elif defined(__GNUC__) || defined(__clang__) || defined(__HIP_PLATFORM_AMD__)
|
||||
#elif defined(__GNUC__) || defined(__clang__)
|
||||
return __builtin_popcountll(v);
|
||||
#elif defined(_MSC_VER) && _defined(_M_X64)
|
||||
return __popcnt64(v);
|
||||
@@ -530,17 +530,17 @@ class TensorView {
|
||||
/**
|
||||
* \brief Number of items in the tensor.
|
||||
*/
|
||||
LINALG_HD std::size_t Size() const { return size_; }
|
||||
[[nodiscard]] LINALG_HD std::size_t Size() const { return size_; }
|
||||
/**
|
||||
* \brief Whether this is a contiguous array, both C and F contiguous returns true.
|
||||
*/
|
||||
LINALG_HD bool Contiguous() const {
|
||||
[[nodiscard]] LINALG_HD bool Contiguous() const {
|
||||
return data_.size() == this->Size() || this->CContiguous() || this->FContiguous();
|
||||
}
|
||||
/**
|
||||
* \brief Whether it's a c-contiguous array.
|
||||
*/
|
||||
LINALG_HD bool CContiguous() const {
|
||||
[[nodiscard]] LINALG_HD bool CContiguous() const {
|
||||
StrideT stride;
|
||||
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
|
||||
// It's contiguous if the stride can be calculated from shape.
|
||||
@@ -550,7 +550,7 @@ class TensorView {
|
||||
/**
|
||||
* \brief Whether it's a f-contiguous array.
|
||||
*/
|
||||
LINALG_HD bool FContiguous() const {
|
||||
[[nodiscard]] LINALG_HD bool FContiguous() const {
|
||||
StrideT stride;
|
||||
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
|
||||
// It's contiguous if the stride can be calculated from shape.
|
||||
|
||||
@@ -29,11 +29,6 @@
|
||||
namespace xgboost {
|
||||
class Json;
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
#define XGBOOST_NODISCARD
|
||||
#else
|
||||
#define XGBOOST_NODISCARD [[nodiscard]]
|
||||
#endif
|
||||
// FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
|
||||
// not be configured by users.
|
||||
/*! \brief meta parameters of the tree */
|
||||
@@ -64,7 +59,7 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
|
||||
|
||||
// Swap byte order for all fields. Useful for transporting models between machines with different
|
||||
// endianness (big endian vs little endian)
|
||||
XGBOOST_NODISCARD TreeParam ByteSwap() const {
|
||||
[[nodiscard]] TreeParam ByteSwap() const {
|
||||
TreeParam x = *this;
|
||||
dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
|
||||
dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
|
||||
@@ -117,7 +112,7 @@ struct RTreeNodeStat {
|
||||
}
|
||||
// Swap byte order for all fields. Useful for transporting models between machines with different
|
||||
// endianness (big endian vs little endian)
|
||||
XGBOOST_NODISCARD RTreeNodeStat ByteSwap() const {
|
||||
[[nodiscard]] RTreeNodeStat ByteSwap() const {
|
||||
RTreeNodeStat x = *this;
|
||||
dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
|
||||
dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
|
||||
@@ -183,51 +178,33 @@ class RegTree : public Model {
|
||||
}
|
||||
|
||||
/*! \brief index of left child */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD int LeftChild() const {
|
||||
return this->cleft_;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
|
||||
/*! \brief index of right child */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD int RightChild() const {
|
||||
return this->cright_;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
|
||||
/*! \brief index of default child when feature is missing */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD int DefaultChild() const {
|
||||
[[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
|
||||
return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
|
||||
}
|
||||
/*! \brief feature index of split condition */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD unsigned SplitIndex() const {
|
||||
[[nodiscard]] XGBOOST_DEVICE unsigned SplitIndex() const {
|
||||
return sindex_ & ((1U << 31) - 1U);
|
||||
}
|
||||
/*! \brief when feature is unknown, whether goes to left child */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool DefaultLeft() const {
|
||||
return (sindex_ >> 31) != 0;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
|
||||
/*! \brief whether current node is leaf node */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool IsLeaf() const {
|
||||
return cleft_ == kInvalidNodeId;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
|
||||
/*! \return get leaf value of leaf node */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD float LeafValue() const {
|
||||
return (this->info_).leaf_value;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
|
||||
/*! \return get split condition of the node */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD SplitCondT SplitCond() const {
|
||||
return (this->info_).split_cond;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
|
||||
/*! \brief get parent of the node */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD int Parent() const {
|
||||
return parent_ & ((1U << 31) - 1);
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
|
||||
/*! \brief whether current node is left child */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool IsLeftChild() const {
|
||||
return (parent_ & (1U << 31)) != 0;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
|
||||
/*! \brief whether this node is deleted */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool IsDeleted() const {
|
||||
return sindex_ == kDeletedNodeMarker;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
|
||||
/*! \brief whether current node is root */
|
||||
XGBOOST_DEVICE XGBOOST_NODISCARD bool IsRoot() const { return parent_ == kInvalidNodeId; }
|
||||
[[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
|
||||
/*!
|
||||
* \brief set the left child
|
||||
* \param nid node id to right child
|
||||
@@ -284,7 +261,7 @@ class RegTree : public Model {
|
||||
info_.leaf_value == b.info_.leaf_value;
|
||||
}
|
||||
|
||||
XGBOOST_NODISCARD Node ByteSwap() const {
|
||||
[[nodiscard]] Node ByteSwap() const {
|
||||
Node x = *this;
|
||||
dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
|
||||
dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
|
||||
@@ -342,15 +319,13 @@ class RegTree : public Model {
|
||||
this->ChangeToLeaf(rid, value);
|
||||
}
|
||||
|
||||
/*! \brief model parameter */
|
||||
TreeParam param;
|
||||
RegTree() {
|
||||
param.Init(Args{});
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param.num_nodes);
|
||||
for (int i = 0; i < param.num_nodes; i++) {
|
||||
param_.Init(Args{});
|
||||
nodes_.resize(param_.num_nodes);
|
||||
stats_.resize(param_.num_nodes);
|
||||
split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param_.num_nodes);
|
||||
for (int i = 0; i < param_.num_nodes; i++) {
|
||||
nodes_[i].SetLeaf(0.0f);
|
||||
nodes_[i].SetParent(kInvalidNodeId);
|
||||
}
|
||||
@@ -359,10 +334,10 @@ class RegTree : public Model {
|
||||
* \brief Constructor that initializes the tree model with shape.
|
||||
*/
|
||||
explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
|
||||
param.num_feature = n_features;
|
||||
param.size_leaf_vector = n_targets;
|
||||
param_.num_feature = n_features;
|
||||
param_.size_leaf_vector = n_targets;
|
||||
if (n_targets > 1) {
|
||||
this->p_mt_tree_.reset(new MultiTargetTree{¶m});
|
||||
this->p_mt_tree_.reset(new MultiTargetTree{¶m_});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -376,17 +351,17 @@ class RegTree : public Model {
|
||||
}
|
||||
|
||||
/*! \brief get const reference to nodes */
|
||||
XGBOOST_NODISCARD const std::vector<Node>& GetNodes() const { return nodes_; }
|
||||
[[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
|
||||
|
||||
/*! \brief get const reference to stats */
|
||||
XGBOOST_NODISCARD const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
|
||||
[[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
|
||||
|
||||
/*! \brief get node statistics given nid */
|
||||
RTreeNodeStat& Stat(int nid) {
|
||||
return stats_[nid];
|
||||
}
|
||||
/*! \brief get node statistics given nid */
|
||||
XGBOOST_NODISCARD const RTreeNodeStat& Stat(int nid) const {
|
||||
[[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
|
||||
return stats_[nid];
|
||||
}
|
||||
|
||||
@@ -406,7 +381,7 @@ class RegTree : public Model {
|
||||
|
||||
bool operator==(const RegTree& b) const {
|
||||
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
|
||||
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
|
||||
deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
|
||||
}
|
||||
/* \brief Iterate through all nodes in this tree.
|
||||
*
|
||||
@@ -439,7 +414,7 @@ class RegTree : public Model {
|
||||
*
|
||||
* \param b The other tree.
|
||||
*/
|
||||
XGBOOST_NODISCARD bool Equal(const RegTree& b) const;
|
||||
[[nodiscard]] bool Equal(const RegTree& b) const;
|
||||
|
||||
/**
|
||||
* \brief Expands a leaf node into two additional leaf nodes.
|
||||
@@ -464,7 +439,9 @@ class RegTree : public Model {
|
||||
bst_float loss_change, float sum_hess, float left_sum,
|
||||
float right_sum,
|
||||
bst_node_t leaf_right_child = kInvalidNodeId);
|
||||
|
||||
/**
|
||||
* \brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
|
||||
*/
|
||||
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
|
||||
linalg::VectorView<float const> base_weight,
|
||||
linalg::VectorView<float const> left_weight,
|
||||
@@ -490,25 +467,54 @@ class RegTree : public Model {
|
||||
bst_float base_weight, bst_float left_leaf_weight,
|
||||
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
|
||||
float left_sum, float right_sum);
|
||||
|
||||
XGBOOST_NODISCARD bool HasCategoricalSplit() const {
|
||||
return !split_categories_.empty();
|
||||
}
|
||||
/**
|
||||
* \brief Whether this tree has categorical split.
|
||||
*/
|
||||
[[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
|
||||
/**
|
||||
* \brief Whether this is a multi-target tree.
|
||||
*/
|
||||
XGBOOST_NODISCARD bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
|
||||
XGBOOST_NODISCARD bst_target_t NumTargets() const { return param.size_leaf_vector; }
|
||||
XGBOOST_NODISCARD auto GetMultiTargetTree() const {
|
||||
[[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
|
||||
/**
|
||||
* \brief The size of leaf weight.
|
||||
*/
|
||||
[[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
|
||||
/**
|
||||
* \brief Get the underlying implementaiton of multi-target tree.
|
||||
*/
|
||||
[[nodiscard]] auto GetMultiTargetTree() const {
|
||||
CHECK(IsMultiTarget());
|
||||
return p_mt_tree_.get();
|
||||
}
|
||||
/**
|
||||
* \brief Get the number of features.
|
||||
*/
|
||||
[[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
|
||||
/**
|
||||
* \brief Get the total number of nodes including deleted ones in this tree.
|
||||
*/
|
||||
[[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
|
||||
/**
|
||||
* \brief Get the total number of valid nodes in this tree.
|
||||
*/
|
||||
[[nodiscard]] bst_node_t NumValidNodes() const noexcept {
|
||||
return param_.num_nodes - param_.num_deleted;
|
||||
}
|
||||
/**
|
||||
* \brief number of extra nodes besides the root
|
||||
*/
|
||||
[[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
|
||||
return param_.num_nodes - 1 - param_.num_deleted;
|
||||
}
|
||||
/* \brief Count number of leaves in tree. */
|
||||
[[nodiscard]] bst_node_t GetNumLeaves() const;
|
||||
[[nodiscard]] bst_node_t GetNumSplitNodes() const;
|
||||
|
||||
/*!
|
||||
* \brief get current depth
|
||||
* \param nid node id
|
||||
*/
|
||||
XGBOOST_NODISCARD std::int32_t GetDepth(bst_node_t nid) const {
|
||||
[[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->Depth(nid);
|
||||
}
|
||||
@@ -519,6 +525,9 @@ class RegTree : public Model {
|
||||
}
|
||||
return depth;
|
||||
}
|
||||
/**
|
||||
* \brief Set the leaf weight for a multi-target tree.
|
||||
*/
|
||||
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
|
||||
CHECK(IsMultiTarget());
|
||||
return this->p_mt_tree_->SetLeaf(nidx, weight);
|
||||
@@ -528,27 +537,15 @@ class RegTree : public Model {
|
||||
* \brief get maximum depth
|
||||
* \param nid node id
|
||||
*/
|
||||
XGBOOST_NODISCARD int MaxDepth(int nid) const {
|
||||
[[nodiscard]] int MaxDepth(int nid) const {
|
||||
if (nodes_[nid].IsLeaf()) return 0;
|
||||
return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
|
||||
MaxDepth(nodes_[nid].RightChild())+1);
|
||||
return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief get maximum depth
|
||||
*/
|
||||
int MaxDepth() {
|
||||
return MaxDepth(0);
|
||||
}
|
||||
|
||||
/*! \brief number of extra nodes besides the root */
|
||||
XGBOOST_NODISCARD int NumExtraNodes() const {
|
||||
return param.num_nodes - 1 - param.num_deleted;
|
||||
}
|
||||
|
||||
/* \brief Count number of leaves in tree. */
|
||||
XGBOOST_NODISCARD bst_node_t GetNumLeaves() const;
|
||||
XGBOOST_NODISCARD bst_node_t GetNumSplitNodes() const;
|
||||
int MaxDepth() { return MaxDepth(0); }
|
||||
|
||||
/*!
|
||||
* \brief dense feature vector that can be taken by RegTree
|
||||
@@ -575,20 +572,20 @@ class RegTree : public Model {
|
||||
* \brief returns the size of the feature vector
|
||||
* \return the size of the feature vector
|
||||
*/
|
||||
XGBOOST_NODISCARD size_t Size() const;
|
||||
[[nodiscard]] size_t Size() const;
|
||||
/*!
|
||||
* \brief get ith value
|
||||
* \param i feature index.
|
||||
* \return the i-th feature value
|
||||
*/
|
||||
XGBOOST_NODISCARD bst_float GetFvalue(size_t i) const;
|
||||
[[nodiscard]] bst_float GetFvalue(size_t i) const;
|
||||
/*!
|
||||
* \brief check whether i-th entry is missing
|
||||
* \param i feature index.
|
||||
* \return whether i-th value is missing.
|
||||
*/
|
||||
XGBOOST_NODISCARD bool IsMissing(size_t i) const;
|
||||
XGBOOST_NODISCARD bool HasMissing() const;
|
||||
[[nodiscard]] bool IsMissing(size_t i) const;
|
||||
[[nodiscard]] bool HasMissing() const;
|
||||
|
||||
|
||||
private:
|
||||
@@ -619,34 +616,34 @@ class RegTree : public Model {
|
||||
* \param format the format to dump the model in
|
||||
* \return the string of dumped model
|
||||
*/
|
||||
XGBOOST_NODISCARD std::string DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||
[[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||
std::string format) const;
|
||||
/*!
|
||||
* \brief Get split type for a node.
|
||||
* \param nidx Index of node.
|
||||
* \return The type of this split. For leaf node it's always kNumerical.
|
||||
*/
|
||||
XGBOOST_NODISCARD FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
|
||||
[[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
|
||||
/*!
|
||||
* \brief Get split types for all nodes.
|
||||
*/
|
||||
XGBOOST_NODISCARD std::vector<FeatureType> const& GetSplitTypes() const {
|
||||
[[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
|
||||
return split_types_;
|
||||
}
|
||||
XGBOOST_NODISCARD common::Span<uint32_t const> GetSplitCategories() const {
|
||||
[[nodiscard]] common::Span<uint32_t const> GetSplitCategories() const {
|
||||
return split_categories_;
|
||||
}
|
||||
/*!
|
||||
* \brief Get the bit storage for categories
|
||||
*/
|
||||
XGBOOST_NODISCARD common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
|
||||
[[nodiscard]] common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
|
||||
auto node_ptr = GetCategoriesMatrix().node_ptr;
|
||||
auto categories = GetCategoriesMatrix().categories;
|
||||
auto segment = node_ptr[nidx];
|
||||
auto node_cats = categories.subspan(segment.beg, segment.size);
|
||||
return node_cats;
|
||||
}
|
||||
XGBOOST_NODISCARD auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
|
||||
[[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
|
||||
|
||||
/**
|
||||
* \brief CSR-like matrix for categorical splits.
|
||||
@@ -665,7 +662,7 @@ class RegTree : public Model {
|
||||
common::Span<Segment const> node_ptr;
|
||||
};
|
||||
|
||||
XGBOOST_NODISCARD CategoricalSplitMatrix GetCategoriesMatrix() const {
|
||||
[[nodiscard]] CategoricalSplitMatrix GetCategoriesMatrix() const {
|
||||
CategoricalSplitMatrix view;
|
||||
view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
|
||||
view.categories = this->GetSplitCategories();
|
||||
@@ -673,55 +670,55 @@ class RegTree : public Model {
|
||||
return view;
|
||||
}
|
||||
|
||||
XGBOOST_NODISCARD bst_feature_t SplitIndex(bst_node_t nidx) const {
|
||||
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->SplitIndex(nidx);
|
||||
}
|
||||
return (*this)[nidx].SplitIndex();
|
||||
}
|
||||
XGBOOST_NODISCARD float SplitCond(bst_node_t nidx) const {
|
||||
[[nodiscard]] float SplitCond(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->SplitCond(nidx);
|
||||
}
|
||||
return (*this)[nidx].SplitCond();
|
||||
}
|
||||
XGBOOST_NODISCARD bool DefaultLeft(bst_node_t nidx) const {
|
||||
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->DefaultLeft(nidx);
|
||||
}
|
||||
return (*this)[nidx].DefaultLeft();
|
||||
}
|
||||
XGBOOST_NODISCARD bool IsRoot(bst_node_t nidx) const {
|
||||
[[nodiscard]] bool IsRoot(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return nidx == kRoot;
|
||||
}
|
||||
return (*this)[nidx].IsRoot();
|
||||
}
|
||||
XGBOOST_NODISCARD bool IsLeaf(bst_node_t nidx) const {
|
||||
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->IsLeaf(nidx);
|
||||
}
|
||||
return (*this)[nidx].IsLeaf();
|
||||
}
|
||||
XGBOOST_NODISCARD bst_node_t Parent(bst_node_t nidx) const {
|
||||
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->Parent(nidx);
|
||||
}
|
||||
return (*this)[nidx].Parent();
|
||||
}
|
||||
XGBOOST_NODISCARD bst_node_t LeftChild(bst_node_t nidx) const {
|
||||
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->LeftChild(nidx);
|
||||
}
|
||||
return (*this)[nidx].LeftChild();
|
||||
}
|
||||
XGBOOST_NODISCARD bst_node_t RightChild(bst_node_t nidx) const {
|
||||
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->RightChild(nidx);
|
||||
}
|
||||
return (*this)[nidx].RightChild();
|
||||
}
|
||||
XGBOOST_NODISCARD bool IsLeftChild(bst_node_t nidx) const {
|
||||
[[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
CHECK_NE(nidx, kRoot);
|
||||
auto p = this->p_mt_tree_->Parent(nidx);
|
||||
@@ -729,7 +726,7 @@ class RegTree : public Model {
|
||||
}
|
||||
return (*this)[nidx].IsLeftChild();
|
||||
}
|
||||
XGBOOST_NODISCARD bst_node_t Size() const {
|
||||
[[nodiscard]] bst_node_t Size() const {
|
||||
if (IsMultiTarget()) {
|
||||
return this->p_mt_tree_->Size();
|
||||
}
|
||||
@@ -740,6 +737,8 @@ class RegTree : public Model {
|
||||
template <bool typed>
|
||||
void LoadCategoricalSplit(Json const& in);
|
||||
void SaveCategoricalSplit(Json* p_out) const;
|
||||
/*! \brief model parameter */
|
||||
TreeParam param_;
|
||||
// vector of nodes
|
||||
std::vector<Node> nodes_;
|
||||
// free node space, used during training process
|
||||
@@ -757,20 +756,20 @@ class RegTree : public Model {
|
||||
// allocate a new node,
|
||||
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
||||
bst_node_t AllocNode() {
|
||||
if (param.num_deleted != 0) {
|
||||
if (param_.num_deleted != 0) {
|
||||
int nid = deleted_nodes_.back();
|
||||
deleted_nodes_.pop_back();
|
||||
nodes_[nid].Reuse();
|
||||
--param.num_deleted;
|
||||
--param_.num_deleted;
|
||||
return nid;
|
||||
}
|
||||
int nd = param.num_nodes++;
|
||||
CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
|
||||
int nd = param_.num_nodes++;
|
||||
CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
|
||||
<< "number of nodes in the tree exceed 2^31";
|
||||
nodes_.resize(param.num_nodes);
|
||||
stats_.resize(param.num_nodes);
|
||||
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param.num_nodes);
|
||||
nodes_.resize(param_.num_nodes);
|
||||
stats_.resize(param_.num_nodes);
|
||||
split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
|
||||
split_categories_segments_.resize(param_.num_nodes);
|
||||
return nd;
|
||||
}
|
||||
// delete a tree node, keep the parent field to allow trace back
|
||||
@@ -785,7 +784,7 @@ class RegTree : public Model {
|
||||
|
||||
deleted_nodes_.push_back(nid);
|
||||
nodes_[nid].MarkDelete();
|
||||
++param.num_deleted;
|
||||
++param_.num_deleted;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user