diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 3cbf7f991..2f2f9efc3 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -245,11 +245,12 @@ test_that("training continuation works", { expect_equal(bst$raw, bst2$raw) expect_equal(dim(bst2$evaluation_log), c(2, 2)) # test continuing from a model in file - xgb.save(bst1, "xgboost.model") - bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = "xgboost.model") + xgb.save(bst1, "xgboost.json") + bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = "xgboost.json") if (!windows_flag && !solaris_flag) expect_equal(bst$raw, bst2$raw) expect_equal(dim(bst2$evaluation_log), c(2, 2)) + file.remove("xgboost.json") }) test_that("model serialization works", { diff --git a/R-package/tests/testthat/test_callbacks.R b/R-package/tests/testthat/test_callbacks.R index 9016c1bcb..fd42f519a 100644 --- a/R-package/tests/testthat/test_callbacks.R +++ b/R-package/tests/testthat/test_callbacks.R @@ -173,16 +173,16 @@ test_that("cb.reset.parameters works as expected", { }) test_that("cb.save.model works as expected", { - files <- c('xgboost_01.model', 'xgboost_02.model', 'xgboost.model') + files <- c('xgboost_01.json', 'xgboost_02.json', 'xgboost.json') for (f in files) if (file.exists(f)) file.remove(f) bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0, - save_period = 1, save_name = "xgboost_%02d.model") - expect_true(file.exists('xgboost_01.model')) - expect_true(file.exists('xgboost_02.model')) - b1 <- xgb.load('xgboost_01.model') + save_period = 1, save_name = "xgboost_%02d.json") + expect_true(file.exists('xgboost_01.json')) + expect_true(file.exists('xgboost_02.json')) + b1 <- xgb.load('xgboost_01.json') expect_equal(xgb.ntree(b1), 1) - b2 <- xgb.load('xgboost_02.model') + b2 <- xgb.load('xgboost_02.json') expect_equal(xgb.ntree(b2), 2) xgb.config(b2) <- xgb.config(bst) @@ -191,9 +191,9 @@ test_that("cb.save.model works as expected", { # save_period = 0 saves the last iteration's model bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0, - save_period = 0) - expect_true(file.exists('xgboost.model')) - b2 <- xgb.load('xgboost.model') + save_period = 0, save_name = 'xgboost.json') + expect_true(file.exists('xgboost.json')) + b2 <- xgb.load('xgboost.json') xgb.config(b2) <- xgb.config(bst) expect_equal(bst$raw, b2$raw) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 480242611..cf30b969c 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -109,7 +109,8 @@ using bst_int = int32_t; // NOLINT using bst_ulong = uint64_t; // NOLINT /*! \brief float type, used for storing statistics */ using bst_float = float; // NOLINT - +/*! \brief Categorical value type. */ +using bst_cat_t = int32_t; // NOLINT /*! \brief Type for data column (feature) index. */ using bst_feature_t = uint32_t; // NOLINT /*! \brief Type for data row index. diff --git a/include/xgboost/data.h b/include/xgboost/data.h index f74dbd2c5..0a0459adc 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -35,7 +35,8 @@ enum class DataType : uint8_t { }; enum class FeatureType : uint8_t { - kNumerical + kNumerical, + kCategorical }; /*! @@ -309,12 +310,6 @@ class SparsePage { } } - /*! - * \brief Push row block into the page. - * \param batch the row batch. - */ - void Push(const dmlc::RowBlock& batch); - /** * \brief Pushes external data batch onto this page * diff --git a/include/xgboost/span.h b/include/xgboost/span.h index ed8c97bd4..6187f396f 100644 --- a/include/xgboost/span.h +++ b/include/xgboost/span.h @@ -101,6 +101,18 @@ namespace common { } while (0); #endif // __CUDA_ARCH__ +#if defined(__CUDA_ARCH__) +#define SPAN_LT(lhs, rhs) \ + if (!((lhs) < (rhs))) { \ + printf("%lu < %lu failed\n", static_cast(lhs), \ + static_cast(rhs)); \ + asm("trap;"); \ + } +#else +#define SPAN_LT(lhs, rhs) \ + SPAN_CHECK((lhs) < (rhs)) +#endif // defined(__CUDA_ARCH__) + namespace detail { /*! * By default, XGBoost uses uint32_t for indexing data. int64_t covers all @@ -515,7 +527,7 @@ class Span { } XGBOOST_DEVICE reference operator[](index_type _idx) const { - SPAN_CHECK(_idx < size()); + SPAN_LT(_idx, size()); return data()[_idx]; } @@ -575,7 +587,6 @@ class Span { detail::ExtentValue::value> { SPAN_CHECK((Count == dynamic_extent) ? (Offset <= size()) : (Offset + Count <= size())); - return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count}; } diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index fd9c69df3..60aa9ce16 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -318,6 +318,8 @@ class RegTree : public Model { param.num_deleted = 0; nodes_.resize(param.num_nodes); stats_.resize(param.num_nodes); + split_types_.resize(param.num_nodes, FeatureType::kNumerical); + split_categories_segments_.resize(param.num_nodes); for (int i = 0; i < param.num_nodes; i ++) { nodes_[i].SetLeaf(0.0f); nodes_[i].SetParent(kInvalidNodeId); @@ -412,30 +414,33 @@ class RegTree : public Model { * \param leaf_right_child The right child index of leaf, by default kInvalidNodeId, * some updaters use the right child index of leaf as a marker */ - void ExpandNode(int nid, unsigned split_index, bst_float split_value, + void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, - bst_node_t leaf_right_child = kInvalidNodeId) { - int pleft = this->AllocNode(); - int pright = this->AllocNode(); - auto &node = nodes_[nid]; - CHECK(node.IsLeaf()); - node.SetLeftChild(pleft); - node.SetRightChild(pright); - nodes_[node.LeftChild()].SetParent(nid, true); - nodes_[node.RightChild()].SetParent(nid, false); - node.SetSplit(split_index, split_value, - default_left); + bst_node_t leaf_right_child = kInvalidNodeId); - nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child); - nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child); - - this->Stat(nid) = {loss_change, sum_hess, base_weight}; - this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight}; - this->Stat(pright) = {0.0f, right_sum, right_leaf_weight}; - } + /** + * \brief Expands a leaf node with categories + * + * \param nid The node index to expand. + * \param split_index Feature index of the split. + * \param split_cat The bitset containing categories + * \param default_left True to default left. + * \param base_weight The base weight, before learning rate. + * \param left_leaf_weight The left leaf weight for prediction, modified by learning rate. + * \param right_leaf_weight The right leaf weight for prediction, modified by learning rate. + * \param loss_change The loss change. + * \param sum_hess The sum hess. + * \param left_sum The sum hess of left leaf. + * \param right_sum The sum hess of right leaf. + */ + void ExpandCategorical(bst_node_t nid, unsigned split_index, + common::Span split_cat, bool default_left, + bst_float base_weight, bst_float left_leaf_weight, + bst_float right_leaf_weight, bst_float loss_change, + float sum_hess, float left_sum, float right_sum); /*! * \brief get current depth @@ -588,6 +593,28 @@ class RegTree : public Model { * \brief calculate the mean value for each node, required for feature contributions */ void FillNodeMeanValues(); + /*! + * \brief Get split type for a node. + * \param nidx Index of node. + * \return The type of this split. For leaf node it's always kNumerical. + */ + FeatureType NodeSplitType(bst_node_t nidx) const { + return split_types_.at(nidx); + } + /*! + * \brief Get split types for all nodes. + */ + std::vector const &GetSplitTypes() const { return split_types_; } + common::Span GetSplitCategories() const { return split_categories_; } + auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; } + + // The fields of split_categories_segments_[i] are set such that + // the range split_categories_[beg:(beg+size)] stores the bitset for + // the matching categories for the i-th node. + struct Segment { + size_t beg {0}; + size_t size {0}; + }; private: // vector of nodes @@ -597,9 +624,16 @@ class RegTree : public Model { // stats of nodes std::vector stats_; std::vector node_mean_values_; + std::vector split_types_; + + // Categories for each internal node. + std::vector split_categories_; + // Ptr to split categories of each node. + std::vector split_categories_segments_; + // allocate a new node, // !!!!!! NOTE: may cause BUG here, nodes.resize - int AllocNode() { + bst_node_t AllocNode() { if (param.num_deleted != 0) { int nid = deleted_nodes_.back(); deleted_nodes_.pop_back(); @@ -612,6 +646,8 @@ class RegTree : public Model { << "number of nodes in the tree exceed 2^31"; nodes_.resize(param.num_nodes); stats_.resize(param.num_nodes); + split_types_.resize(param.num_nodes, FeatureType::kNumerical); + split_categories_segments_.resize(param.num_nodes); return nd; } // delete a tree node, keep the parent field to allow trace back diff --git a/src/common/bitfield.h b/src/common/bitfield.h index 4353a5269..c727360b3 100644 --- a/src/common/bitfield.h +++ b/src/common/bitfield.h @@ -16,6 +16,7 @@ #if defined(__CUDACC__) #include #include +#include "device_helpers.cuh" #endif // defined(__CUDACC__) #include "xgboost/span.h" @@ -54,23 +55,24 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr * * \tparam Direction Whether the bits start from left or from right. */ -template +template struct BitFieldContainer { - using value_type = VT; // NOLINT + using value_type = std::conditional_t; // NOLINT using pointer = value_type*; // NOLINT static value_type constexpr kValueSize = sizeof(value_type) * 8; static value_type constexpr kOne = 1; // force correct type. struct Pos { - value_type int_pos {0}; - value_type bit_pos {0}; + std::remove_const_t int_pos {0}; + std::remove_const_t bit_pos {0}; }; private: common::Span bits_; static_assert(!std::is_signed::value, "Must use unsiged type as underlying storage."); + public: XGBOOST_DEVICE static Pos ToBitPos(value_type pos) { Pos pos_v; if (pos == 0) { @@ -92,7 +94,7 @@ struct BitFieldContainer { /*\brief Compute the size of needed memory allocation. The returned value is in terms * of number of elements with `BitFieldContainer::value_type'. */ - static size_t ComputeStorageSize(size_t size) { + XGBOOST_DEVICE static size_t ComputeStorageSize(size_t size) { return common::DivRoundUp(size, kValueSize); } #if defined(__CUDA_ARCH__) @@ -134,19 +136,19 @@ struct BitFieldContainer { #endif // defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__) - __device__ void Set(value_type pos) { + __device__ auto Set(value_type pos) { Pos pos_v = Direction::Shift(ToBitPos(pos)); value_type& value = bits_[pos_v.int_pos]; value_type set_bit = kOne << pos_v.bit_pos; - static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), ""); - AtomicOr(reinterpret_cast(&value), set_bit); + using Type = typename dh::detail::AtomicDispatcher::Type; + atomicOr(reinterpret_cast(&value), set_bit); } __device__ void Clear(value_type pos) { Pos pos_v = Direction::Shift(ToBitPos(pos)); value_type& value = bits_[pos_v.int_pos]; value_type clear_bit = ~(kOne << pos_v.bit_pos); - static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), ""); - AtomicAnd(reinterpret_cast(&value), clear_bit); + using Type = typename dh::detail::AtomicDispatcher::Type; + atomicAnd(reinterpret_cast(&value), clear_bit); } #else void Set(value_type pos) { @@ -165,6 +167,7 @@ struct BitFieldContainer { XGBOOST_DEVICE bool Check(Pos pos_v) const { pos_v = Direction::Shift(pos_v); + SPAN_LT(pos_v.int_pos, bits_.size()); value_type const value = bits_[pos_v.int_pos]; value_type const test_bit = kOne << pos_v.bit_pos; value_type result = test_bit & value; @@ -179,10 +182,11 @@ struct BitFieldContainer { XGBOOST_DEVICE pointer Data() const { return bits_.data(); } - friend std::ostream& operator<<(std::ostream& os, BitFieldContainer field) { + inline friend std::ostream & + operator<<(std::ostream &os, BitFieldContainer field) { os << "Bits " << "storage size: " << field.bits_.size() << "\n"; for (typename common::Span::index_type i = 0; i < field.bits_.size(); ++i) { - std::bitset::kValueSize> bset(field.bits_[i]); + std::bitset::kValueSize> bset(field.bits_[i]); os << bset << "\n"; } return os; @@ -190,9 +194,9 @@ struct BitFieldContainer { }; // Bits start from left most bits (most significant bit). -template -struct LBitsPolicy : public BitFieldContainer> { - using Container = BitFieldContainer>; +template +struct LBitsPolicy : public BitFieldContainer, IsConst> { + using Container = BitFieldContainer, IsConst>; using Pos = typename Container::Pos; using value_type = typename Container::value_type; // NOLINT @@ -215,38 +219,13 @@ struct RBitsPolicy : public BitFieldContainer> { } }; -// Format: BitField, underlying type must be unsigned. +// Format: BitField, underlying type +// must be unsigned. using LBitField64 = BitFieldContainer>; using RBitField8 = BitFieldContainer>; -#if defined(__CUDACC__) - -template -inline void PrintDeviceBits(std::string name, BitFieldContainer field) { - std::cout << "Bits: " << name << std::endl; - std::vector::value_type> h_field_bits(field.bits_.size()); - thrust::copy(thrust::device_ptr::value_type>(field.bits_.data()), - thrust::device_ptr::value_type>( - field.bits_.data() + field.bits_.size()), - h_field_bits.data()); - BitFieldContainer h_field; - h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()}; - std::cout << h_field; -} - -inline void PrintDeviceStorage(std::string name, common::Span list) { - std::cout << name << std::endl; - std::vector h_list(list.size()); - thrust::copy(thrust::device_ptr(list.data()), - thrust::device_ptr(list.data() + list.size()), - h_list.data()); - for (auto v : h_list) { - std::cout << v << ", "; - } - std::cout << std::endl; -} - -#endif // defined(__CUDACC__) +using LBitField32 = BitFieldContainer>; +using CLBitField32 = BitFieldContainer, true>; } // namespace xgboost #endif // XGBOOST_COMMON_BITFIELD_H_ diff --git a/src/common/categorical.h b/src/common/categorical.h new file mode 100644 index 000000000..02899a901 --- /dev/null +++ b/src/common/categorical.h @@ -0,0 +1,50 @@ +/*! + * Copyright 2020 by XGBoost Contributors + * \file categorical.h + */ +#ifndef XGBOOST_COMMON_CATEGORICAL_H_ +#define XGBOOST_COMMON_CATEGORICAL_H_ + +#include "xgboost/base.h" +#include "xgboost/data.h" +#include "xgboost/span.h" +#include "xgboost/parameter.h" +#include "bitfield.h" + +namespace xgboost { +namespace common { +// Cast the categorical type. +template +XGBOOST_DEVICE bst_cat_t AsCat(T const& v) { + return static_cast(v); +} + +/* \brief Whether is fidx a categorical feature. + * + * \param ft Feature type for all features. + * \param fidx Feature index. + * \return Whether feature pointed by fidx is categorical feature. + */ +inline XGBOOST_DEVICE bool IsCat(Span ft, bst_feature_t fidx) { + return !ft.empty() && ft[fidx] == FeatureType::kCategorical; +} + +/* \brief Whether should it traverse to left branch of a tree. + * + * For one hot split, go to left if it's NOT the matching category. + */ +inline XGBOOST_DEVICE bool Decision(common::Span cats, bst_cat_t cat) { + auto pos = CLBitField32::ToBitPos(cat); + if (pos.int_pos >= cats.size()) { + return true; + } + CLBitField32 const s_cats(cats); + return !s_cats.Check(cat); +} + +using CatBitField = LBitField32; +using KCatBitField = CLBitField32; +} // namespace common +} // namespace xgboost + +#endif // XGBOOST_COMMON_CATEGORICAL_H_ diff --git a/src/common/json.cc b/src/common/json.cc index 18d8694d1..98a6abbd6 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -275,6 +275,9 @@ Json& JsonNumber::operator[](int ind) { bool JsonNumber::operator==(Value const& rhs) const { if (!IsA(&rhs)) { return false; } + if (std::isinf(number_)) { + return std::isinf(Cast(&rhs)->GetNumber()); + } return std::abs(number_ - Cast(&rhs)->GetNumber()) < kRtEps; } diff --git a/src/data/data.cc b/src/data/data.cc index d7d18f189..ad74008eb 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -199,8 +199,10 @@ void LoadFeatureType(std::vectorconst& type_names, std::vectoremplace_back(FeatureType::kNumerical); } else if (elem == "q") { types->emplace_back(FeatureType::kNumerical); + } else if (elem == "categorical") { + types->emplace_back(FeatureType::kCategorical); } else { - LOG(FATAL) << "All feature_types must be {int, float, i, q}"; + LOG(FATAL) << "All feature_types must be one of {int, float, i, q, categorical}."; } } } diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 7f9721aef..27521c68c 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -18,6 +18,7 @@ #include "param.h" #include "../common/common.h" +#include "../common/categorical.h" namespace xgboost { // register tree parameter @@ -662,6 +663,53 @@ bst_node_t RegTree::GetNumSplitNodes() const { return splits; } +void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, + bool default_left, bst_float base_weight, + bst_float left_leaf_weight, + bst_float right_leaf_weight, bst_float loss_change, + float sum_hess, float left_sum, float right_sum, + bst_node_t leaf_right_child) { + int pleft = this->AllocNode(); + int pright = this->AllocNode(); + auto &node = nodes_[nid]; + CHECK(node.IsLeaf()); + node.SetLeftChild(pleft); + node.SetRightChild(pright); + nodes_[node.LeftChild()].SetParent(nid, true); + nodes_[node.RightChild()].SetParent(nid, false); + node.SetSplit(split_index, split_value, default_left); + + nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child); + nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child); + + this->Stat(nid) = {loss_change, sum_hess, base_weight}; + this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight}; + this->Stat(pright) = {0.0f, right_sum, right_leaf_weight}; + + this->split_types_.at(nid) = FeatureType::kNumerical; +} + +void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index, + common::Span split_cat, bool default_left, + bst_float base_weight, + bst_float left_leaf_weight, + bst_float right_leaf_weight, + bst_float loss_change, float sum_hess, + float left_sum, float right_sum) { + this->ExpandNode(nid, split_index, std::numeric_limits::quiet_NaN(), + default_left, base_weight, + left_leaf_weight, right_leaf_weight, loss_change, sum_hess, + left_sum, right_sum); + + size_t orig_size = split_categories_.size(); + this->split_categories_.resize(orig_size + split_cat.size()); + std::copy(split_cat.data(), split_cat.data() + split_cat.size(), + split_categories_.begin() + orig_size); + this->split_types_.at(nid) = FeatureType::kCategorical; + this->split_categories_segments_.at(nid).beg = orig_size; + this->split_categories_segments_.at(nid).size = split_cat.size(); +} + void RegTree::Load(dmlc::Stream* fi) { CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam)); if (!DMLC_IO_NO_ENDIAN_SWAP) { @@ -751,11 +799,24 @@ void RegTree::LoadModel(Json const& in) { auto const& default_left = get(in["default_left"]); CHECK_EQ(default_left.size(), n_nodes); + bool has_cat = get(in).find("split_type") != get(in).cend(); + std::vector split_type; + std::vector categories; + if (has_cat) { + split_type = get(in["split_type"]); + categories = get(in["categories"]); + } + + stats_.clear(); nodes_.clear(); stats_.resize(n_nodes); nodes_.resize(n_nodes); + split_types_.resize(n_nodes); + split_categories_segments_.resize(n_nodes); + + CHECK_EQ(n_nodes, split_categories_segments_.size()); for (int32_t i = 0; i < n_nodes; ++i) { auto& s = stats_[i]; s.loss_chg = get(loss_changes[i]); @@ -771,6 +832,31 @@ void RegTree::LoadModel(Json const& in) { float cond { get(conds[i]) }; bool dft_left { get(default_left[i]) }; n = Node{left, right, parent, ind, cond, dft_left}; + + if (has_cat) { + split_types_[i] = + static_cast(get(split_type[i])); + auto const& j_categories = get(categories[i]); + bst_cat_t max_cat { std::numeric_limits::min() }; + for (auto const& j_cat : j_categories) { + auto cat = common::AsCat(get(j_cat)); + max_cat = std::max(max_cat, cat); + } + size_t size = max_cat == std::numeric_limits::min() + ? 0 + : common::KCatBitField::ComputeStorageSize(max_cat); + std::vector cat_bits_storage(size); + common::CatBitField cat_bits{common::Span(cat_bits_storage)}; + for (auto const& j_cat : j_categories) { + cat_bits.Set(common::AsCat(get(j_cat))); + } + auto begin = split_categories_.size(); + split_categories_.resize(begin + cat_bits_storage.size()); + std::copy(cat_bits_storage.begin(), cat_bits_storage.end(), + split_categories_.begin() + begin); + split_categories_segments_[i].beg = begin; + split_categories_segments_[i].size = cat_bits_storage.size(); + } } deleted_nodes_.clear(); @@ -811,8 +897,11 @@ void RegTree::SaveModel(Json* p_out) const { std::vector indices(n_nodes); std::vector conds(n_nodes); std::vector default_left(n_nodes); + std::vector split_type(n_nodes); - for (int32_t i = 0; i < n_nodes; ++i) { + std::vector categories(n_nodes); + + for (bst_node_t i = 0; i < n_nodes; ++i) { auto const& s = stats_[i]; loss_changes[i] = s.loss_chg; sum_hessian[i] = s.sum_hess; @@ -826,6 +915,24 @@ void RegTree::SaveModel(Json* p_out) const { indices[i] = static_cast(n.SplitIndex()); conds[i] = n.SplitCond(); default_left[i] = n.DefaultLeft(); + + std::vector categories_temp; + // This condition is only for being compatibale with older version of XGBoost model + // that doesn't have categorical data support. + if (this->GetSplitTypes().size() == static_cast(n_nodes)) { + CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes); + split_type[i] = static_cast(this->NodeSplitType(i)); + auto beg = this->GetSplitCategoriesPtr().at(i).beg; + auto size = this->GetSplitCategoriesPtr().at(i).size; + auto node_categories = this->GetSplitCategories().subspan(beg, size); + common::KCatBitField const cat_bits(node_categories); + for (size_t i = 0; i < cat_bits.Size(); ++i) { + if (cat_bits.Check(i)) { + categories_temp.emplace_back(static_cast(i)); + } + } + } + categories[i] = Array(categories_temp); } out["loss_changes"] = std::move(loss_changes); @@ -839,6 +946,12 @@ void RegTree::SaveModel(Json* p_out) const { out["split_indices"] = std::move(indices); out["split_conditions"] = std::move(conds); out["default_left"] = std::move(default_left); + + out["categories"] = categories; + + if (this->GetSplitTypes().size() == static_cast(n_nodes)) { + out["split_type"] = std::move(split_type); + } } void RegTree::FillNodeMeanValues() { diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 1dbc5fc2c..4b090db46 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -1,9 +1,10 @@ // Copyright by Contributors #include -#include #include "../helpers.h" #include "dmlc/filesystem.h" #include "xgboost/json_io.h" +#include "xgboost/tree_model.h" +#include "../../../src/common/bitfield.h" namespace xgboost { #if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines @@ -82,7 +83,7 @@ TEST(Tree, Load) { tree.Load(fi.get()); EXPECT_EQ(tree.GetDepth(1), 1); EXPECT_EQ(tree[0].SplitCond(), 0.5f); - EXPECT_EQ(tree[0].SplitIndex(), 5); + EXPECT_EQ(tree[0].SplitIndex(), 5ul); EXPECT_EQ(tree[1].LeafValue(), 0.1f); EXPECT_TRUE(tree[1].IsLeaf()); } @@ -105,6 +106,51 @@ TEST(Tree, AllocateNode) { ASSERT_TRUE(nodes.at(2).IsLeaf()); } +TEST(Tree, ExpandCategoricalFeature) { + { + RegTree tree; + tree.ExpandCategorical(0, 0, {}, true, 1.0, 2.0, 3.0, 11.0, 2.0, + /*left_sum=*/3.0, /*right_sum=*/4.0); + ASSERT_EQ(tree.GetNodes().size(), 3ul); + ASSERT_EQ(tree.GetNumLeaves(), 2); + ASSERT_EQ(tree.GetSplitTypes().size(), 3ul); + ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical); + ASSERT_EQ(tree.GetSplitTypes()[1], FeatureType::kNumerical); + ASSERT_EQ(tree.GetSplitTypes()[2], FeatureType::kNumerical); + ASSERT_EQ(tree.GetSplitCategories().size(), 0ul); + ASSERT_TRUE(std::isnan(tree[0].SplitCond())); + } + { + RegTree tree; + bst_cat_t cat = 33; + std::vector split_cats(LBitField32::ComputeStorageSize(cat+1)); + LBitField32 bitset {split_cats}; + bitset.Set(cat); + tree.ExpandCategorical(0, 0, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0, + /*left_sum=*/3.0, /*right_sum=*/4.0); + auto categories = tree.GetSplitCategories(); + auto segments = tree.GetSplitCategoriesPtr(); + auto got = categories.subspan(segments[0].beg, segments[0].size); + ASSERT_TRUE(std::equal(got.cbegin(), got.cend(), split_cats.cbegin())); + + Json out{Object()}; + tree.SaveModel(&out); + + RegTree loaded_tree; + loaded_tree.LoadModel(out); + + auto const& cat_ptr = loaded_tree.GetSplitCategoriesPtr(); + ASSERT_EQ(cat_ptr.size(), 3ul); + ASSERT_EQ(cat_ptr[0].beg, 0ul); + ASSERT_EQ(cat_ptr[0].size, 2ul); + + auto loaded_categories = loaded_tree.GetSplitCategories(); + auto loaded_root = loaded_categories.subspan(cat_ptr[0].beg, cat_ptr[0].size); + ASSERT_TRUE(std::equal(loaded_root.begin(), loaded_root.end(), split_cats.begin())); + } +} + +namespace { RegTree ConstructTree() { RegTree tree; tree.ExpandNode( @@ -123,6 +169,7 @@ RegTree ConstructTree() { /*right_sum=*/0.0f); return tree; } +} // anonymous namespace TEST(Tree, DumpJson) { auto tree = ConstructTree(); @@ -133,14 +180,14 @@ TEST(Tree, DumpJson) { while ((iter = str.find("leaf", iter + 1)) != std::string::npos) { n_leaves++; } - ASSERT_EQ(n_leaves, 4); + ASSERT_EQ(n_leaves, 4ul); size_t n_conditions = 0; iter = 0; while ((iter = str.find("split_condition", iter + 1)) != std::string::npos) { n_conditions++; } - ASSERT_EQ(n_conditions, 3); + ASSERT_EQ(n_conditions, 3ul); fmap.PushBack(0, "feat_0", "i"); fmap.PushBack(1, "feat_1", "q"); @@ -156,7 +203,7 @@ TEST(Tree, DumpJson) { auto j_tree = Json::Load({str.c_str(), str.size()}); - ASSERT_EQ(get(j_tree["children"]).size(), 2); + ASSERT_EQ(get(j_tree["children"]).size(), 2ul); } TEST(Tree, DumpText) { @@ -168,14 +215,14 @@ TEST(Tree, DumpText) { while ((iter = str.find("leaf", iter + 1)) != std::string::npos) { n_leaves++; } - ASSERT_EQ(n_leaves, 4); + ASSERT_EQ(n_leaves, 4ul); iter = 0; size_t n_conditions = 0; while ((iter = str.find("gain", iter + 1)) != std::string::npos) { n_conditions++; } - ASSERT_EQ(n_conditions, 3); + ASSERT_EQ(n_conditions, 3ul); ASSERT_NE(str.find("[f0<0]"), std::string::npos); ASSERT_NE(str.find("[f1<1]"), std::string::npos); @@ -204,14 +251,14 @@ TEST(Tree, DumpDot) { while ((iter = str.find("leaf", iter + 1)) != std::string::npos) { n_leaves++; } - ASSERT_EQ(n_leaves, 4); + ASSERT_EQ(n_leaves, 4ul); size_t n_edges = 0; iter = 0; while ((iter = str.find("->", iter + 1)) != std::string::npos) { n_edges++; } - ASSERT_EQ(n_edges, 6); + ASSERT_EQ(n_edges, 6ul); fmap.PushBack(0, "feat_0", "i"); fmap.PushBack(1, "feat_1", "q"); @@ -238,12 +285,12 @@ TEST(Tree, JsonIO) { ASSERT_EQ(get(tparam["num_nodes"]), "3"); ASSERT_EQ(get(tparam["size_leaf_vector"]), "0"); - ASSERT_EQ(get(j_tree["left_children"]).size(), 3); - ASSERT_EQ(get(j_tree["right_children"]).size(), 3); - ASSERT_EQ(get(j_tree["parents"]).size(), 3); - ASSERT_EQ(get(j_tree["split_indices"]).size(), 3); - ASSERT_EQ(get(j_tree["split_conditions"]).size(), 3); - ASSERT_EQ(get(j_tree["default_left"]).size(), 3); + ASSERT_EQ(get(j_tree["left_children"]).size(), 3ul); + ASSERT_EQ(get(j_tree["right_children"]).size(), 3ul); + ASSERT_EQ(get(j_tree["parents"]).size(), 3ul); + ASSERT_EQ(get(j_tree["split_indices"]).size(), 3ul); + ASSERT_EQ(get(j_tree["split_conditions"]).size(), 3ul); + ASSERT_EQ(get(j_tree["default_left"]).size(), 3ul); RegTree loaded_tree; loaded_tree.LoadModel(j_tree); @@ -268,5 +315,4 @@ TEST(Tree, JsonIO) { ASSERT_EQ(loaded_tree[1].RightChild(), -1); ASSERT_TRUE(tree.Equal(loaded_tree)); } - } // namespace xgboost