Expand categorical node. (#6028)
Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
9a4e8b1d81
commit
20c95be625
@ -245,11 +245,12 @@ test_that("training continuation works", {
|
|||||||
expect_equal(bst$raw, bst2$raw)
|
expect_equal(bst$raw, bst2$raw)
|
||||||
expect_equal(dim(bst2$evaluation_log), c(2, 2))
|
expect_equal(dim(bst2$evaluation_log), c(2, 2))
|
||||||
# test continuing from a model in file
|
# test continuing from a model in file
|
||||||
xgb.save(bst1, "xgboost.model")
|
xgb.save(bst1, "xgboost.json")
|
||||||
bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = "xgboost.model")
|
bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = "xgboost.json")
|
||||||
if (!windows_flag && !solaris_flag)
|
if (!windows_flag && !solaris_flag)
|
||||||
expect_equal(bst$raw, bst2$raw)
|
expect_equal(bst$raw, bst2$raw)
|
||||||
expect_equal(dim(bst2$evaluation_log), c(2, 2))
|
expect_equal(dim(bst2$evaluation_log), c(2, 2))
|
||||||
|
file.remove("xgboost.json")
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("model serialization works", {
|
test_that("model serialization works", {
|
||||||
|
|||||||
@ -173,16 +173,16 @@ test_that("cb.reset.parameters works as expected", {
|
|||||||
})
|
})
|
||||||
|
|
||||||
test_that("cb.save.model works as expected", {
|
test_that("cb.save.model works as expected", {
|
||||||
files <- c('xgboost_01.model', 'xgboost_02.model', 'xgboost.model')
|
files <- c('xgboost_01.json', 'xgboost_02.json', 'xgboost.json')
|
||||||
for (f in files) if (file.exists(f)) file.remove(f)
|
for (f in files) if (file.exists(f)) file.remove(f)
|
||||||
|
|
||||||
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
|
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
|
||||||
save_period = 1, save_name = "xgboost_%02d.model")
|
save_period = 1, save_name = "xgboost_%02d.json")
|
||||||
expect_true(file.exists('xgboost_01.model'))
|
expect_true(file.exists('xgboost_01.json'))
|
||||||
expect_true(file.exists('xgboost_02.model'))
|
expect_true(file.exists('xgboost_02.json'))
|
||||||
b1 <- xgb.load('xgboost_01.model')
|
b1 <- xgb.load('xgboost_01.json')
|
||||||
expect_equal(xgb.ntree(b1), 1)
|
expect_equal(xgb.ntree(b1), 1)
|
||||||
b2 <- xgb.load('xgboost_02.model')
|
b2 <- xgb.load('xgboost_02.json')
|
||||||
expect_equal(xgb.ntree(b2), 2)
|
expect_equal(xgb.ntree(b2), 2)
|
||||||
|
|
||||||
xgb.config(b2) <- xgb.config(bst)
|
xgb.config(b2) <- xgb.config(bst)
|
||||||
@ -191,9 +191,9 @@ test_that("cb.save.model works as expected", {
|
|||||||
|
|
||||||
# save_period = 0 saves the last iteration's model
|
# save_period = 0 saves the last iteration's model
|
||||||
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
|
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
|
||||||
save_period = 0)
|
save_period = 0, save_name = 'xgboost.json')
|
||||||
expect_true(file.exists('xgboost.model'))
|
expect_true(file.exists('xgboost.json'))
|
||||||
b2 <- xgb.load('xgboost.model')
|
b2 <- xgb.load('xgboost.json')
|
||||||
xgb.config(b2) <- xgb.config(bst)
|
xgb.config(b2) <- xgb.config(bst)
|
||||||
expect_equal(bst$raw, b2$raw)
|
expect_equal(bst$raw, b2$raw)
|
||||||
|
|
||||||
|
|||||||
@ -109,7 +109,8 @@ using bst_int = int32_t; // NOLINT
|
|||||||
using bst_ulong = uint64_t; // NOLINT
|
using bst_ulong = uint64_t; // NOLINT
|
||||||
/*! \brief float type, used for storing statistics */
|
/*! \brief float type, used for storing statistics */
|
||||||
using bst_float = float; // NOLINT
|
using bst_float = float; // NOLINT
|
||||||
|
/*! \brief Categorical value type. */
|
||||||
|
using bst_cat_t = int32_t; // NOLINT
|
||||||
/*! \brief Type for data column (feature) index. */
|
/*! \brief Type for data column (feature) index. */
|
||||||
using bst_feature_t = uint32_t; // NOLINT
|
using bst_feature_t = uint32_t; // NOLINT
|
||||||
/*! \brief Type for data row index.
|
/*! \brief Type for data row index.
|
||||||
|
|||||||
@ -35,7 +35,8 @@ enum class DataType : uint8_t {
|
|||||||
};
|
};
|
||||||
|
|
||||||
enum class FeatureType : uint8_t {
|
enum class FeatureType : uint8_t {
|
||||||
kNumerical
|
kNumerical,
|
||||||
|
kCategorical
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
@ -309,12 +310,6 @@ class SparsePage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief Push row block into the page.
|
|
||||||
* \param batch the row batch.
|
|
||||||
*/
|
|
||||||
void Push(const dmlc::RowBlock<uint32_t>& batch);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Pushes external data batch onto this page
|
* \brief Pushes external data batch onto this page
|
||||||
*
|
*
|
||||||
|
|||||||
@ -101,6 +101,18 @@ namespace common {
|
|||||||
} while (0);
|
} while (0);
|
||||||
#endif // __CUDA_ARCH__
|
#endif // __CUDA_ARCH__
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__)
|
||||||
|
#define SPAN_LT(lhs, rhs) \
|
||||||
|
if (!((lhs) < (rhs))) { \
|
||||||
|
printf("%lu < %lu failed\n", static_cast<size_t>(lhs), \
|
||||||
|
static_cast<size_t>(rhs)); \
|
||||||
|
asm("trap;"); \
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
#define SPAN_LT(lhs, rhs) \
|
||||||
|
SPAN_CHECK((lhs) < (rhs))
|
||||||
|
#endif // defined(__CUDA_ARCH__)
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
/*!
|
/*!
|
||||||
* By default, XGBoost uses uint32_t for indexing data. int64_t covers all
|
* By default, XGBoost uses uint32_t for indexing data. int64_t covers all
|
||||||
@ -515,7 +527,7 @@ class Span {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE reference operator[](index_type _idx) const {
|
XGBOOST_DEVICE reference operator[](index_type _idx) const {
|
||||||
SPAN_CHECK(_idx < size());
|
SPAN_LT(_idx, size());
|
||||||
return data()[_idx];
|
return data()[_idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -575,7 +587,6 @@ class Span {
|
|||||||
detail::ExtentValue<Extent, Offset, Count>::value> {
|
detail::ExtentValue<Extent, Offset, Count>::value> {
|
||||||
SPAN_CHECK((Count == dynamic_extent) ?
|
SPAN_CHECK((Count == dynamic_extent) ?
|
||||||
(Offset <= size()) : (Offset + Count <= size()));
|
(Offset <= size()) : (Offset + Count <= size()));
|
||||||
|
|
||||||
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
|
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -318,6 +318,8 @@ class RegTree : public Model {
|
|||||||
param.num_deleted = 0;
|
param.num_deleted = 0;
|
||||||
nodes_.resize(param.num_nodes);
|
nodes_.resize(param.num_nodes);
|
||||||
stats_.resize(param.num_nodes);
|
stats_.resize(param.num_nodes);
|
||||||
|
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
||||||
|
split_categories_segments_.resize(param.num_nodes);
|
||||||
for (int i = 0; i < param.num_nodes; i ++) {
|
for (int i = 0; i < param.num_nodes; i ++) {
|
||||||
nodes_[i].SetLeaf(0.0f);
|
nodes_[i].SetLeaf(0.0f);
|
||||||
nodes_[i].SetParent(kInvalidNodeId);
|
nodes_[i].SetParent(kInvalidNodeId);
|
||||||
@ -412,30 +414,33 @@ class RegTree : public Model {
|
|||||||
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
|
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
|
||||||
* some updaters use the right child index of leaf as a marker
|
* some updaters use the right child index of leaf as a marker
|
||||||
*/
|
*/
|
||||||
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
|
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
|
||||||
bool default_left, bst_float base_weight,
|
bool default_left, bst_float base_weight,
|
||||||
bst_float left_leaf_weight, bst_float right_leaf_weight,
|
bst_float left_leaf_weight, bst_float right_leaf_weight,
|
||||||
bst_float loss_change, float sum_hess, float left_sum,
|
bst_float loss_change, float sum_hess, float left_sum,
|
||||||
float right_sum,
|
float right_sum,
|
||||||
bst_node_t leaf_right_child = kInvalidNodeId) {
|
bst_node_t leaf_right_child = kInvalidNodeId);
|
||||||
int pleft = this->AllocNode();
|
|
||||||
int pright = this->AllocNode();
|
|
||||||
auto &node = nodes_[nid];
|
|
||||||
CHECK(node.IsLeaf());
|
|
||||||
node.SetLeftChild(pleft);
|
|
||||||
node.SetRightChild(pright);
|
|
||||||
nodes_[node.LeftChild()].SetParent(nid, true);
|
|
||||||
nodes_[node.RightChild()].SetParent(nid, false);
|
|
||||||
node.SetSplit(split_index, split_value,
|
|
||||||
default_left);
|
|
||||||
|
|
||||||
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
|
/**
|
||||||
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
|
* \brief Expands a leaf node with categories
|
||||||
|
*
|
||||||
this->Stat(nid) = {loss_change, sum_hess, base_weight};
|
* \param nid The node index to expand.
|
||||||
this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
|
* \param split_index Feature index of the split.
|
||||||
this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
|
* \param split_cat The bitset containing categories
|
||||||
}
|
* \param default_left True to default left.
|
||||||
|
* \param base_weight The base weight, before learning rate.
|
||||||
|
* \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
|
||||||
|
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
|
||||||
|
* \param loss_change The loss change.
|
||||||
|
* \param sum_hess The sum hess.
|
||||||
|
* \param left_sum The sum hess of left leaf.
|
||||||
|
* \param right_sum The sum hess of right leaf.
|
||||||
|
*/
|
||||||
|
void ExpandCategorical(bst_node_t nid, unsigned split_index,
|
||||||
|
common::Span<uint32_t> split_cat, bool default_left,
|
||||||
|
bst_float base_weight, bst_float left_leaf_weight,
|
||||||
|
bst_float right_leaf_weight, bst_float loss_change,
|
||||||
|
float sum_hess, float left_sum, float right_sum);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief get current depth
|
* \brief get current depth
|
||||||
@ -588,6 +593,28 @@ class RegTree : public Model {
|
|||||||
* \brief calculate the mean value for each node, required for feature contributions
|
* \brief calculate the mean value for each node, required for feature contributions
|
||||||
*/
|
*/
|
||||||
void FillNodeMeanValues();
|
void FillNodeMeanValues();
|
||||||
|
/*!
|
||||||
|
* \brief Get split type for a node.
|
||||||
|
* \param nidx Index of node.
|
||||||
|
* \return The type of this split. For leaf node it's always kNumerical.
|
||||||
|
*/
|
||||||
|
FeatureType NodeSplitType(bst_node_t nidx) const {
|
||||||
|
return split_types_.at(nidx);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief Get split types for all nodes.
|
||||||
|
*/
|
||||||
|
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
|
||||||
|
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
|
||||||
|
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
|
||||||
|
|
||||||
|
// The fields of split_categories_segments_[i] are set such that
|
||||||
|
// the range split_categories_[beg:(beg+size)] stores the bitset for
|
||||||
|
// the matching categories for the i-th node.
|
||||||
|
struct Segment {
|
||||||
|
size_t beg {0};
|
||||||
|
size_t size {0};
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// vector of nodes
|
// vector of nodes
|
||||||
@ -597,9 +624,16 @@ class RegTree : public Model {
|
|||||||
// stats of nodes
|
// stats of nodes
|
||||||
std::vector<RTreeNodeStat> stats_;
|
std::vector<RTreeNodeStat> stats_;
|
||||||
std::vector<bst_float> node_mean_values_;
|
std::vector<bst_float> node_mean_values_;
|
||||||
|
std::vector<FeatureType> split_types_;
|
||||||
|
|
||||||
|
// Categories for each internal node.
|
||||||
|
std::vector<uint32_t> split_categories_;
|
||||||
|
// Ptr to split categories of each node.
|
||||||
|
std::vector<Segment> split_categories_segments_;
|
||||||
|
|
||||||
// allocate a new node,
|
// allocate a new node,
|
||||||
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
||||||
int AllocNode() {
|
bst_node_t AllocNode() {
|
||||||
if (param.num_deleted != 0) {
|
if (param.num_deleted != 0) {
|
||||||
int nid = deleted_nodes_.back();
|
int nid = deleted_nodes_.back();
|
||||||
deleted_nodes_.pop_back();
|
deleted_nodes_.pop_back();
|
||||||
@ -612,6 +646,8 @@ class RegTree : public Model {
|
|||||||
<< "number of nodes in the tree exceed 2^31";
|
<< "number of nodes in the tree exceed 2^31";
|
||||||
nodes_.resize(param.num_nodes);
|
nodes_.resize(param.num_nodes);
|
||||||
stats_.resize(param.num_nodes);
|
stats_.resize(param.num_nodes);
|
||||||
|
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
|
||||||
|
split_categories_segments_.resize(param.num_nodes);
|
||||||
return nd;
|
return nd;
|
||||||
}
|
}
|
||||||
// delete a tree node, keep the parent field to allow trace back
|
// delete a tree node, keep the parent field to allow trace back
|
||||||
|
|||||||
@ -16,6 +16,7 @@
|
|||||||
#if defined(__CUDACC__)
|
#if defined(__CUDACC__)
|
||||||
#include <thrust/copy.h>
|
#include <thrust/copy.h>
|
||||||
#include <thrust/device_ptr.h>
|
#include <thrust/device_ptr.h>
|
||||||
|
#include "device_helpers.cuh"
|
||||||
#endif // defined(__CUDACC__)
|
#endif // defined(__CUDACC__)
|
||||||
|
|
||||||
#include "xgboost/span.h"
|
#include "xgboost/span.h"
|
||||||
@ -54,23 +55,24 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr
|
|||||||
*
|
*
|
||||||
* \tparam Direction Whether the bits start from left or from right.
|
* \tparam Direction Whether the bits start from left or from right.
|
||||||
*/
|
*/
|
||||||
template <typename VT, typename Direction>
|
template <typename VT, typename Direction, bool IsConst = false>
|
||||||
struct BitFieldContainer {
|
struct BitFieldContainer {
|
||||||
using value_type = VT; // NOLINT
|
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
|
||||||
using pointer = value_type*; // NOLINT
|
using pointer = value_type*; // NOLINT
|
||||||
|
|
||||||
static value_type constexpr kValueSize = sizeof(value_type) * 8;
|
static value_type constexpr kValueSize = sizeof(value_type) * 8;
|
||||||
static value_type constexpr kOne = 1; // force correct type.
|
static value_type constexpr kOne = 1; // force correct type.
|
||||||
|
|
||||||
struct Pos {
|
struct Pos {
|
||||||
value_type int_pos {0};
|
std::remove_const_t<value_type> int_pos {0};
|
||||||
value_type bit_pos {0};
|
std::remove_const_t<value_type> bit_pos {0};
|
||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
common::Span<value_type> bits_;
|
common::Span<value_type> bits_;
|
||||||
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");
|
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");
|
||||||
|
|
||||||
|
public:
|
||||||
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
|
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
|
||||||
Pos pos_v;
|
Pos pos_v;
|
||||||
if (pos == 0) {
|
if (pos == 0) {
|
||||||
@ -92,7 +94,7 @@ struct BitFieldContainer {
|
|||||||
/*\brief Compute the size of needed memory allocation. The returned value is in terms
|
/*\brief Compute the size of needed memory allocation. The returned value is in terms
|
||||||
* of number of elements with `BitFieldContainer::value_type'.
|
* of number of elements with `BitFieldContainer::value_type'.
|
||||||
*/
|
*/
|
||||||
static size_t ComputeStorageSize(size_t size) {
|
XGBOOST_DEVICE static size_t ComputeStorageSize(size_t size) {
|
||||||
return common::DivRoundUp(size, kValueSize);
|
return common::DivRoundUp(size, kValueSize);
|
||||||
}
|
}
|
||||||
#if defined(__CUDA_ARCH__)
|
#if defined(__CUDA_ARCH__)
|
||||||
@ -134,19 +136,19 @@ struct BitFieldContainer {
|
|||||||
#endif // defined(__CUDA_ARCH__)
|
#endif // defined(__CUDA_ARCH__)
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__)
|
#if defined(__CUDA_ARCH__)
|
||||||
__device__ void Set(value_type pos) {
|
__device__ auto Set(value_type pos) {
|
||||||
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
||||||
value_type& value = bits_[pos_v.int_pos];
|
value_type& value = bits_[pos_v.int_pos];
|
||||||
value_type set_bit = kOne << pos_v.bit_pos;
|
value_type set_bit = kOne << pos_v.bit_pos;
|
||||||
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
|
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
|
||||||
AtomicOr(reinterpret_cast<BitFieldAtomicType*>(&value), set_bit);
|
atomicOr(reinterpret_cast<Type *>(&value), set_bit);
|
||||||
}
|
}
|
||||||
__device__ void Clear(value_type pos) {
|
__device__ void Clear(value_type pos) {
|
||||||
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
Pos pos_v = Direction::Shift(ToBitPos(pos));
|
||||||
value_type& value = bits_[pos_v.int_pos];
|
value_type& value = bits_[pos_v.int_pos];
|
||||||
value_type clear_bit = ~(kOne << pos_v.bit_pos);
|
value_type clear_bit = ~(kOne << pos_v.bit_pos);
|
||||||
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
|
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
|
||||||
AtomicAnd(reinterpret_cast<BitFieldAtomicType*>(&value), clear_bit);
|
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
void Set(value_type pos) {
|
void Set(value_type pos) {
|
||||||
@ -165,6 +167,7 @@ struct BitFieldContainer {
|
|||||||
|
|
||||||
XGBOOST_DEVICE bool Check(Pos pos_v) const {
|
XGBOOST_DEVICE bool Check(Pos pos_v) const {
|
||||||
pos_v = Direction::Shift(pos_v);
|
pos_v = Direction::Shift(pos_v);
|
||||||
|
SPAN_LT(pos_v.int_pos, bits_.size());
|
||||||
value_type const value = bits_[pos_v.int_pos];
|
value_type const value = bits_[pos_v.int_pos];
|
||||||
value_type const test_bit = kOne << pos_v.bit_pos;
|
value_type const test_bit = kOne << pos_v.bit_pos;
|
||||||
value_type result = test_bit & value;
|
value_type result = test_bit & value;
|
||||||
@ -179,10 +182,11 @@ struct BitFieldContainer {
|
|||||||
|
|
||||||
XGBOOST_DEVICE pointer Data() const { return bits_.data(); }
|
XGBOOST_DEVICE pointer Data() const { return bits_.data(); }
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, BitFieldContainer<VT, Direction> field) {
|
inline friend std::ostream &
|
||||||
|
operator<<(std::ostream &os, BitFieldContainer<VT, Direction, IsConst> field) {
|
||||||
os << "Bits " << "storage size: " << field.bits_.size() << "\n";
|
os << "Bits " << "storage size: " << field.bits_.size() << "\n";
|
||||||
for (typename common::Span<value_type>::index_type i = 0; i < field.bits_.size(); ++i) {
|
for (typename common::Span<value_type>::index_type i = 0; i < field.bits_.size(); ++i) {
|
||||||
std::bitset<BitFieldContainer<VT, Direction>::kValueSize> bset(field.bits_[i]);
|
std::bitset<BitFieldContainer<VT, Direction, IsConst>::kValueSize> bset(field.bits_[i]);
|
||||||
os << bset << "\n";
|
os << bset << "\n";
|
||||||
}
|
}
|
||||||
return os;
|
return os;
|
||||||
@ -190,9 +194,9 @@ struct BitFieldContainer {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Bits start from left most bits (most significant bit).
|
// Bits start from left most bits (most significant bit).
|
||||||
template <typename VT>
|
template <typename VT, bool IsConst = false>
|
||||||
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT>> {
|
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst> {
|
||||||
using Container = BitFieldContainer<VT, LBitsPolicy<VT>>;
|
using Container = BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst>;
|
||||||
using Pos = typename Container::Pos;
|
using Pos = typename Container::Pos;
|
||||||
using value_type = typename Container::value_type; // NOLINT
|
using value_type = typename Container::value_type; // NOLINT
|
||||||
|
|
||||||
@ -215,38 +219,13 @@ struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Format: <Direction>BitField<size of underlying type in bits>, underlying type must be unsigned.
|
// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
|
||||||
|
// must be unsigned.
|
||||||
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
|
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
|
||||||
using RBitField8 = BitFieldContainer<uint8_t, RBitsPolicy<unsigned char>>;
|
using RBitField8 = BitFieldContainer<uint8_t, RBitsPolicy<unsigned char>>;
|
||||||
|
|
||||||
#if defined(__CUDACC__)
|
using LBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t>>;
|
||||||
|
using CLBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t, true>, true>;
|
||||||
template <typename V, typename D>
|
|
||||||
inline void PrintDeviceBits(std::string name, BitFieldContainer<V, D> field) {
|
|
||||||
std::cout << "Bits: " << name << std::endl;
|
|
||||||
std::vector<typename BitFieldContainer<V, D>::value_type> h_field_bits(field.bits_.size());
|
|
||||||
thrust::copy(thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(field.bits_.data()),
|
|
||||||
thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(
|
|
||||||
field.bits_.data() + field.bits_.size()),
|
|
||||||
h_field_bits.data());
|
|
||||||
BitFieldContainer<V, D> h_field;
|
|
||||||
h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()};
|
|
||||||
std::cout << h_field;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void PrintDeviceStorage(std::string name, common::Span<int32_t> list) {
|
|
||||||
std::cout << name << std::endl;
|
|
||||||
std::vector<int32_t> h_list(list.size());
|
|
||||||
thrust::copy(thrust::device_ptr<int32_t>(list.data()),
|
|
||||||
thrust::device_ptr<int32_t>(list.data() + list.size()),
|
|
||||||
h_list.data());
|
|
||||||
for (auto v : h_list) {
|
|
||||||
std::cout << v << ", ";
|
|
||||||
}
|
|
||||||
std::cout << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // defined(__CUDACC__)
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
#endif // XGBOOST_COMMON_BITFIELD_H_
|
#endif // XGBOOST_COMMON_BITFIELD_H_
|
||||||
|
|||||||
50
src/common/categorical.h
Normal file
50
src/common/categorical.h
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2020 by XGBoost Contributors
|
||||||
|
* \file categorical.h
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
|
||||||
|
#define XGBOOST_COMMON_CATEGORICAL_H_
|
||||||
|
|
||||||
|
#include "xgboost/base.h"
|
||||||
|
#include "xgboost/data.h"
|
||||||
|
#include "xgboost/span.h"
|
||||||
|
#include "xgboost/parameter.h"
|
||||||
|
#include "bitfield.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
// Cast the categorical type.
|
||||||
|
template <typename T>
|
||||||
|
XGBOOST_DEVICE bst_cat_t AsCat(T const& v) {
|
||||||
|
return static_cast<bst_cat_t>(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* \brief Whether is fidx a categorical feature.
|
||||||
|
*
|
||||||
|
* \param ft Feature type for all features.
|
||||||
|
* \param fidx Feature index.
|
||||||
|
* \return Whether feature pointed by fidx is categorical feature.
|
||||||
|
*/
|
||||||
|
inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx) {
|
||||||
|
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* \brief Whether should it traverse to left branch of a tree.
|
||||||
|
*
|
||||||
|
* For one hot split, go to left if it's NOT the matching category.
|
||||||
|
*/
|
||||||
|
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
|
||||||
|
auto pos = CLBitField32::ToBitPos(cat);
|
||||||
|
if (pos.int_pos >= cats.size()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
CLBitField32 const s_cats(cats);
|
||||||
|
return !s_cats.Check(cat);
|
||||||
|
}
|
||||||
|
|
||||||
|
using CatBitField = LBitField32;
|
||||||
|
using KCatBitField = CLBitField32;
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
#endif // XGBOOST_COMMON_CATEGORICAL_H_
|
||||||
@ -275,6 +275,9 @@ Json& JsonNumber::operator[](int ind) {
|
|||||||
|
|
||||||
bool JsonNumber::operator==(Value const& rhs) const {
|
bool JsonNumber::operator==(Value const& rhs) const {
|
||||||
if (!IsA<JsonNumber>(&rhs)) { return false; }
|
if (!IsA<JsonNumber>(&rhs)) { return false; }
|
||||||
|
if (std::isinf(number_)) {
|
||||||
|
return std::isinf(Cast<JsonNumber const>(&rhs)->GetNumber());
|
||||||
|
}
|
||||||
return std::abs(number_ - Cast<JsonNumber const>(&rhs)->GetNumber()) < kRtEps;
|
return std::abs(number_ - Cast<JsonNumber const>(&rhs)->GetNumber()) < kRtEps;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -199,8 +199,10 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
|
|||||||
types->emplace_back(FeatureType::kNumerical);
|
types->emplace_back(FeatureType::kNumerical);
|
||||||
} else if (elem == "q") {
|
} else if (elem == "q") {
|
||||||
types->emplace_back(FeatureType::kNumerical);
|
types->emplace_back(FeatureType::kNumerical);
|
||||||
|
} else if (elem == "categorical") {
|
||||||
|
types->emplace_back(FeatureType::kCategorical);
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "All feature_types must be {int, float, i, q}";
|
LOG(FATAL) << "All feature_types must be one of {int, float, i, q, categorical}.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,6 +18,7 @@
|
|||||||
|
|
||||||
#include "param.h"
|
#include "param.h"
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
|
#include "../common/categorical.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
// register tree parameter
|
// register tree parameter
|
||||||
@ -662,6 +663,53 @@ bst_node_t RegTree::GetNumSplitNodes() const {
|
|||||||
return splits;
|
return splits;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
|
||||||
|
bool default_left, bst_float base_weight,
|
||||||
|
bst_float left_leaf_weight,
|
||||||
|
bst_float right_leaf_weight, bst_float loss_change,
|
||||||
|
float sum_hess, float left_sum, float right_sum,
|
||||||
|
bst_node_t leaf_right_child) {
|
||||||
|
int pleft = this->AllocNode();
|
||||||
|
int pright = this->AllocNode();
|
||||||
|
auto &node = nodes_[nid];
|
||||||
|
CHECK(node.IsLeaf());
|
||||||
|
node.SetLeftChild(pleft);
|
||||||
|
node.SetRightChild(pright);
|
||||||
|
nodes_[node.LeftChild()].SetParent(nid, true);
|
||||||
|
nodes_[node.RightChild()].SetParent(nid, false);
|
||||||
|
node.SetSplit(split_index, split_value, default_left);
|
||||||
|
|
||||||
|
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
|
||||||
|
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
|
||||||
|
|
||||||
|
this->Stat(nid) = {loss_change, sum_hess, base_weight};
|
||||||
|
this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
|
||||||
|
this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
|
||||||
|
|
||||||
|
this->split_types_.at(nid) = FeatureType::kNumerical;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
|
||||||
|
common::Span<uint32_t> split_cat, bool default_left,
|
||||||
|
bst_float base_weight,
|
||||||
|
bst_float left_leaf_weight,
|
||||||
|
bst_float right_leaf_weight,
|
||||||
|
bst_float loss_change, float sum_hess,
|
||||||
|
float left_sum, float right_sum) {
|
||||||
|
this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
|
||||||
|
default_left, base_weight,
|
||||||
|
left_leaf_weight, right_leaf_weight, loss_change, sum_hess,
|
||||||
|
left_sum, right_sum);
|
||||||
|
|
||||||
|
size_t orig_size = split_categories_.size();
|
||||||
|
this->split_categories_.resize(orig_size + split_cat.size());
|
||||||
|
std::copy(split_cat.data(), split_cat.data() + split_cat.size(),
|
||||||
|
split_categories_.begin() + orig_size);
|
||||||
|
this->split_types_.at(nid) = FeatureType::kCategorical;
|
||||||
|
this->split_categories_segments_.at(nid).beg = orig_size;
|
||||||
|
this->split_categories_segments_.at(nid).size = split_cat.size();
|
||||||
|
}
|
||||||
|
|
||||||
void RegTree::Load(dmlc::Stream* fi) {
|
void RegTree::Load(dmlc::Stream* fi) {
|
||||||
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
||||||
if (!DMLC_IO_NO_ENDIAN_SWAP) {
|
if (!DMLC_IO_NO_ENDIAN_SWAP) {
|
||||||
@ -751,11 +799,24 @@ void RegTree::LoadModel(Json const& in) {
|
|||||||
auto const& default_left = get<Array const>(in["default_left"]);
|
auto const& default_left = get<Array const>(in["default_left"]);
|
||||||
CHECK_EQ(default_left.size(), n_nodes);
|
CHECK_EQ(default_left.size(), n_nodes);
|
||||||
|
|
||||||
|
bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
|
||||||
|
std::vector<Json> split_type;
|
||||||
|
std::vector<Json> categories;
|
||||||
|
if (has_cat) {
|
||||||
|
split_type = get<Array const>(in["split_type"]);
|
||||||
|
categories = get<Array const>(in["categories"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
stats_.clear();
|
stats_.clear();
|
||||||
nodes_.clear();
|
nodes_.clear();
|
||||||
|
|
||||||
stats_.resize(n_nodes);
|
stats_.resize(n_nodes);
|
||||||
nodes_.resize(n_nodes);
|
nodes_.resize(n_nodes);
|
||||||
|
split_types_.resize(n_nodes);
|
||||||
|
split_categories_segments_.resize(n_nodes);
|
||||||
|
|
||||||
|
CHECK_EQ(n_nodes, split_categories_segments_.size());
|
||||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||||
auto& s = stats_[i];
|
auto& s = stats_[i];
|
||||||
s.loss_chg = get<Number const>(loss_changes[i]);
|
s.loss_chg = get<Number const>(loss_changes[i]);
|
||||||
@ -771,6 +832,31 @@ void RegTree::LoadModel(Json const& in) {
|
|||||||
float cond { get<Number const>(conds[i]) };
|
float cond { get<Number const>(conds[i]) };
|
||||||
bool dft_left { get<Boolean const>(default_left[i]) };
|
bool dft_left { get<Boolean const>(default_left[i]) };
|
||||||
n = Node{left, right, parent, ind, cond, dft_left};
|
n = Node{left, right, parent, ind, cond, dft_left};
|
||||||
|
|
||||||
|
if (has_cat) {
|
||||||
|
split_types_[i] =
|
||||||
|
static_cast<FeatureType>(get<Integer const>(split_type[i]));
|
||||||
|
auto const& j_categories = get<Array const>(categories[i]);
|
||||||
|
bst_cat_t max_cat { std::numeric_limits<bst_cat_t>::min() };
|
||||||
|
for (auto const& j_cat : j_categories) {
|
||||||
|
auto cat = common::AsCat(get<Integer const>(j_cat));
|
||||||
|
max_cat = std::max(max_cat, cat);
|
||||||
|
}
|
||||||
|
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
|
||||||
|
? 0
|
||||||
|
: common::KCatBitField::ComputeStorageSize(max_cat);
|
||||||
|
std::vector<uint32_t> cat_bits_storage(size);
|
||||||
|
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
|
||||||
|
for (auto const& j_cat : j_categories) {
|
||||||
|
cat_bits.Set(common::AsCat(get<Integer const>(j_cat)));
|
||||||
|
}
|
||||||
|
auto begin = split_categories_.size();
|
||||||
|
split_categories_.resize(begin + cat_bits_storage.size());
|
||||||
|
std::copy(cat_bits_storage.begin(), cat_bits_storage.end(),
|
||||||
|
split_categories_.begin() + begin);
|
||||||
|
split_categories_segments_[i].beg = begin;
|
||||||
|
split_categories_segments_[i].size = cat_bits_storage.size();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deleted_nodes_.clear();
|
deleted_nodes_.clear();
|
||||||
@ -811,8 +897,11 @@ void RegTree::SaveModel(Json* p_out) const {
|
|||||||
std::vector<Json> indices(n_nodes);
|
std::vector<Json> indices(n_nodes);
|
||||||
std::vector<Json> conds(n_nodes);
|
std::vector<Json> conds(n_nodes);
|
||||||
std::vector<Json> default_left(n_nodes);
|
std::vector<Json> default_left(n_nodes);
|
||||||
|
std::vector<Json> split_type(n_nodes);
|
||||||
|
|
||||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
std::vector<Json> categories(n_nodes);
|
||||||
|
|
||||||
|
for (bst_node_t i = 0; i < n_nodes; ++i) {
|
||||||
auto const& s = stats_[i];
|
auto const& s = stats_[i];
|
||||||
loss_changes[i] = s.loss_chg;
|
loss_changes[i] = s.loss_chg;
|
||||||
sum_hessian[i] = s.sum_hess;
|
sum_hessian[i] = s.sum_hess;
|
||||||
@ -826,6 +915,24 @@ void RegTree::SaveModel(Json* p_out) const {
|
|||||||
indices[i] = static_cast<I>(n.SplitIndex());
|
indices[i] = static_cast<I>(n.SplitIndex());
|
||||||
conds[i] = n.SplitCond();
|
conds[i] = n.SplitCond();
|
||||||
default_left[i] = n.DefaultLeft();
|
default_left[i] = n.DefaultLeft();
|
||||||
|
|
||||||
|
std::vector<Json> categories_temp;
|
||||||
|
// This condition is only for being compatibale with older version of XGBoost model
|
||||||
|
// that doesn't have categorical data support.
|
||||||
|
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
|
||||||
|
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
|
||||||
|
split_type[i] = static_cast<I>(this->NodeSplitType(i));
|
||||||
|
auto beg = this->GetSplitCategoriesPtr().at(i).beg;
|
||||||
|
auto size = this->GetSplitCategoriesPtr().at(i).size;
|
||||||
|
auto node_categories = this->GetSplitCategories().subspan(beg, size);
|
||||||
|
common::KCatBitField const cat_bits(node_categories);
|
||||||
|
for (size_t i = 0; i < cat_bits.Size(); ++i) {
|
||||||
|
if (cat_bits.Check(i)) {
|
||||||
|
categories_temp.emplace_back(static_cast<Integer::Int>(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
categories[i] = Array(categories_temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
out["loss_changes"] = std::move(loss_changes);
|
out["loss_changes"] = std::move(loss_changes);
|
||||||
@ -839,6 +946,12 @@ void RegTree::SaveModel(Json* p_out) const {
|
|||||||
out["split_indices"] = std::move(indices);
|
out["split_indices"] = std::move(indices);
|
||||||
out["split_conditions"] = std::move(conds);
|
out["split_conditions"] = std::move(conds);
|
||||||
out["default_left"] = std::move(default_left);
|
out["default_left"] = std::move(default_left);
|
||||||
|
|
||||||
|
out["categories"] = categories;
|
||||||
|
|
||||||
|
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
|
||||||
|
out["split_type"] = std::move(split_type);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void RegTree::FillNodeMeanValues() {
|
void RegTree::FillNodeMeanValues() {
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
// Copyright by Contributors
|
// Copyright by Contributors
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/tree_model.h>
|
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
#include "dmlc/filesystem.h"
|
#include "dmlc/filesystem.h"
|
||||||
#include "xgboost/json_io.h"
|
#include "xgboost/json_io.h"
|
||||||
|
#include "xgboost/tree_model.h"
|
||||||
|
#include "../../../src/common/bitfield.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
|
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
|
||||||
@ -82,7 +83,7 @@ TEST(Tree, Load) {
|
|||||||
tree.Load(fi.get());
|
tree.Load(fi.get());
|
||||||
EXPECT_EQ(tree.GetDepth(1), 1);
|
EXPECT_EQ(tree.GetDepth(1), 1);
|
||||||
EXPECT_EQ(tree[0].SplitCond(), 0.5f);
|
EXPECT_EQ(tree[0].SplitCond(), 0.5f);
|
||||||
EXPECT_EQ(tree[0].SplitIndex(), 5);
|
EXPECT_EQ(tree[0].SplitIndex(), 5ul);
|
||||||
EXPECT_EQ(tree[1].LeafValue(), 0.1f);
|
EXPECT_EQ(tree[1].LeafValue(), 0.1f);
|
||||||
EXPECT_TRUE(tree[1].IsLeaf());
|
EXPECT_TRUE(tree[1].IsLeaf());
|
||||||
}
|
}
|
||||||
@ -105,6 +106,51 @@ TEST(Tree, AllocateNode) {
|
|||||||
ASSERT_TRUE(nodes.at(2).IsLeaf());
|
ASSERT_TRUE(nodes.at(2).IsLeaf());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Tree, ExpandCategoricalFeature) {
|
||||||
|
{
|
||||||
|
RegTree tree;
|
||||||
|
tree.ExpandCategorical(0, 0, {}, true, 1.0, 2.0, 3.0, 11.0, 2.0,
|
||||||
|
/*left_sum=*/3.0, /*right_sum=*/4.0);
|
||||||
|
ASSERT_EQ(tree.GetNodes().size(), 3ul);
|
||||||
|
ASSERT_EQ(tree.GetNumLeaves(), 2);
|
||||||
|
ASSERT_EQ(tree.GetSplitTypes().size(), 3ul);
|
||||||
|
ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
|
||||||
|
ASSERT_EQ(tree.GetSplitTypes()[1], FeatureType::kNumerical);
|
||||||
|
ASSERT_EQ(tree.GetSplitTypes()[2], FeatureType::kNumerical);
|
||||||
|
ASSERT_EQ(tree.GetSplitCategories().size(), 0ul);
|
||||||
|
ASSERT_TRUE(std::isnan(tree[0].SplitCond()));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
RegTree tree;
|
||||||
|
bst_cat_t cat = 33;
|
||||||
|
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(cat+1));
|
||||||
|
LBitField32 bitset {split_cats};
|
||||||
|
bitset.Set(cat);
|
||||||
|
tree.ExpandCategorical(0, 0, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
|
||||||
|
/*left_sum=*/3.0, /*right_sum=*/4.0);
|
||||||
|
auto categories = tree.GetSplitCategories();
|
||||||
|
auto segments = tree.GetSplitCategoriesPtr();
|
||||||
|
auto got = categories.subspan(segments[0].beg, segments[0].size);
|
||||||
|
ASSERT_TRUE(std::equal(got.cbegin(), got.cend(), split_cats.cbegin()));
|
||||||
|
|
||||||
|
Json out{Object()};
|
||||||
|
tree.SaveModel(&out);
|
||||||
|
|
||||||
|
RegTree loaded_tree;
|
||||||
|
loaded_tree.LoadModel(out);
|
||||||
|
|
||||||
|
auto const& cat_ptr = loaded_tree.GetSplitCategoriesPtr();
|
||||||
|
ASSERT_EQ(cat_ptr.size(), 3ul);
|
||||||
|
ASSERT_EQ(cat_ptr[0].beg, 0ul);
|
||||||
|
ASSERT_EQ(cat_ptr[0].size, 2ul);
|
||||||
|
|
||||||
|
auto loaded_categories = loaded_tree.GetSplitCategories();
|
||||||
|
auto loaded_root = loaded_categories.subspan(cat_ptr[0].beg, cat_ptr[0].size);
|
||||||
|
ASSERT_TRUE(std::equal(loaded_root.begin(), loaded_root.end(), split_cats.begin()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
RegTree ConstructTree() {
|
RegTree ConstructTree() {
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
tree.ExpandNode(
|
tree.ExpandNode(
|
||||||
@ -123,6 +169,7 @@ RegTree ConstructTree() {
|
|||||||
/*right_sum=*/0.0f);
|
/*right_sum=*/0.0f);
|
||||||
return tree;
|
return tree;
|
||||||
}
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
TEST(Tree, DumpJson) {
|
TEST(Tree, DumpJson) {
|
||||||
auto tree = ConstructTree();
|
auto tree = ConstructTree();
|
||||||
@ -133,14 +180,14 @@ TEST(Tree, DumpJson) {
|
|||||||
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
||||||
n_leaves++;
|
n_leaves++;
|
||||||
}
|
}
|
||||||
ASSERT_EQ(n_leaves, 4);
|
ASSERT_EQ(n_leaves, 4ul);
|
||||||
|
|
||||||
size_t n_conditions = 0;
|
size_t n_conditions = 0;
|
||||||
iter = 0;
|
iter = 0;
|
||||||
while ((iter = str.find("split_condition", iter + 1)) != std::string::npos) {
|
while ((iter = str.find("split_condition", iter + 1)) != std::string::npos) {
|
||||||
n_conditions++;
|
n_conditions++;
|
||||||
}
|
}
|
||||||
ASSERT_EQ(n_conditions, 3);
|
ASSERT_EQ(n_conditions, 3ul);
|
||||||
|
|
||||||
fmap.PushBack(0, "feat_0", "i");
|
fmap.PushBack(0, "feat_0", "i");
|
||||||
fmap.PushBack(1, "feat_1", "q");
|
fmap.PushBack(1, "feat_1", "q");
|
||||||
@ -156,7 +203,7 @@ TEST(Tree, DumpJson) {
|
|||||||
|
|
||||||
|
|
||||||
auto j_tree = Json::Load({str.c_str(), str.size()});
|
auto j_tree = Json::Load({str.c_str(), str.size()});
|
||||||
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2);
|
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2ul);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Tree, DumpText) {
|
TEST(Tree, DumpText) {
|
||||||
@ -168,14 +215,14 @@ TEST(Tree, DumpText) {
|
|||||||
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
||||||
n_leaves++;
|
n_leaves++;
|
||||||
}
|
}
|
||||||
ASSERT_EQ(n_leaves, 4);
|
ASSERT_EQ(n_leaves, 4ul);
|
||||||
|
|
||||||
iter = 0;
|
iter = 0;
|
||||||
size_t n_conditions = 0;
|
size_t n_conditions = 0;
|
||||||
while ((iter = str.find("gain", iter + 1)) != std::string::npos) {
|
while ((iter = str.find("gain", iter + 1)) != std::string::npos) {
|
||||||
n_conditions++;
|
n_conditions++;
|
||||||
}
|
}
|
||||||
ASSERT_EQ(n_conditions, 3);
|
ASSERT_EQ(n_conditions, 3ul);
|
||||||
|
|
||||||
ASSERT_NE(str.find("[f0<0]"), std::string::npos);
|
ASSERT_NE(str.find("[f0<0]"), std::string::npos);
|
||||||
ASSERT_NE(str.find("[f1<1]"), std::string::npos);
|
ASSERT_NE(str.find("[f1<1]"), std::string::npos);
|
||||||
@ -204,14 +251,14 @@ TEST(Tree, DumpDot) {
|
|||||||
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
||||||
n_leaves++;
|
n_leaves++;
|
||||||
}
|
}
|
||||||
ASSERT_EQ(n_leaves, 4);
|
ASSERT_EQ(n_leaves, 4ul);
|
||||||
|
|
||||||
size_t n_edges = 0;
|
size_t n_edges = 0;
|
||||||
iter = 0;
|
iter = 0;
|
||||||
while ((iter = str.find("->", iter + 1)) != std::string::npos) {
|
while ((iter = str.find("->", iter + 1)) != std::string::npos) {
|
||||||
n_edges++;
|
n_edges++;
|
||||||
}
|
}
|
||||||
ASSERT_EQ(n_edges, 6);
|
ASSERT_EQ(n_edges, 6ul);
|
||||||
|
|
||||||
fmap.PushBack(0, "feat_0", "i");
|
fmap.PushBack(0, "feat_0", "i");
|
||||||
fmap.PushBack(1, "feat_1", "q");
|
fmap.PushBack(1, "feat_1", "q");
|
||||||
@ -238,12 +285,12 @@ TEST(Tree, JsonIO) {
|
|||||||
ASSERT_EQ(get<String>(tparam["num_nodes"]), "3");
|
ASSERT_EQ(get<String>(tparam["num_nodes"]), "3");
|
||||||
ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0");
|
ASSERT_EQ(get<String>(tparam["size_leaf_vector"]), "0");
|
||||||
|
|
||||||
ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3);
|
ASSERT_EQ(get<Array const>(j_tree["left_children"]).size(), 3ul);
|
||||||
ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3);
|
ASSERT_EQ(get<Array const>(j_tree["right_children"]).size(), 3ul);
|
||||||
ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3);
|
ASSERT_EQ(get<Array const>(j_tree["parents"]).size(), 3ul);
|
||||||
ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3);
|
ASSERT_EQ(get<Array const>(j_tree["split_indices"]).size(), 3ul);
|
||||||
ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3);
|
ASSERT_EQ(get<Array const>(j_tree["split_conditions"]).size(), 3ul);
|
||||||
ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3);
|
ASSERT_EQ(get<Array const>(j_tree["default_left"]).size(), 3ul);
|
||||||
|
|
||||||
RegTree loaded_tree;
|
RegTree loaded_tree;
|
||||||
loaded_tree.LoadModel(j_tree);
|
loaded_tree.LoadModel(j_tree);
|
||||||
@ -268,5 +315,4 @@ TEST(Tree, JsonIO) {
|
|||||||
ASSERT_EQ(loaded_tree[1].RightChild(), -1);
|
ASSERT_EQ(loaded_tree[1].RightChild(), -1);
|
||||||
ASSERT_TRUE(tree.Equal(loaded_tree));
|
ASSERT_TRUE(tree.Equal(loaded_tree));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user