Expand categorical node. (#6028)
Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
|
||||
#include "param.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/categorical.h"
|
||||
|
||||
namespace xgboost {
|
||||
// register tree parameter
|
||||
@@ -662,6 +663,53 @@ bst_node_t RegTree::GetNumSplitNodes() const {
|
||||
return splits;
|
||||
}
|
||||
|
||||
// Expand the leaf `nid` into a numerical split with two newly allocated leaf
// children.
//
//   nid               node to expand; it must currently be a leaf.
//   split_index       forwarded to Node::SetSplit as the split index.
//   split_value       forwarded to Node::SetSplit as the split value.
//   default_left      forwarded to Node::SetSplit.
//   base_weight       third entry of the statistics stored for `nid`.
//   left_leaf_weight  leaf value (and stored weight) of the new left child.
//   right_leaf_weight leaf value (and stored weight) of the new right child.
//   loss_change       first entry of the statistics stored for `nid`.
//   sum_hess          second entry of the statistics stored for `nid`.
//   left_sum          second entry of the statistics stored for the left child.
//   right_sum         second entry of the statistics stored for the right child.
//   leaf_right_child  passed as the second argument of SetLeaf for both children.
void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
                         bool default_left, bst_float base_weight,
                         bst_float left_leaf_weight,
                         bst_float right_leaf_weight, bst_float loss_change,
                         float sum_hess, float left_sum, float right_sum,
                         bst_node_t leaf_right_child) {
  // Allocate the children first, left before right.
  // NOTE(review): presumably AllocNode() may grow nodes_, which is why the
  // reference to the parent is only taken afterwards -- confirm.
  int const left_id = this->AllocNode();
  int const right_id = this->AllocNode();

  auto &parent = nodes_[nid];
  CHECK(parent.IsLeaf());

  // Link parent and children in both directions; the boolean passed to
  // SetParent is `true` for the left child and `false` for the right one.
  parent.SetLeftChild(left_id);
  parent.SetRightChild(right_id);
  nodes_[left_id].SetParent(nid, true);
  nodes_[right_id].SetParent(nid, false);
  parent.SetSplit(split_index, split_value, default_left);

  // Both children start out as leaves.
  nodes_[left_id].SetLeaf(left_leaf_weight, leaf_right_child);
  nodes_[right_id].SetLeaf(right_leaf_weight, leaf_right_child);

  // Per-node statistics: the parent keeps the split statistics, the children
  // get a zero first entry together with their own sum and weight.
  this->Stat(nid) = {loss_change, sum_hess, base_weight};
  this->Stat(left_id) = {0.0f, left_sum, left_leaf_weight};
  this->Stat(right_id) = {0.0f, right_sum, right_leaf_weight};

  // This entry point always produces a numerical split.
  this->split_types_.at(nid) = FeatureType::kNumerical;
}
|
||||
|
||||
// Expand the leaf `nid` into a categorical split.
//
// The node is first expanded through ExpandNode with a quiet NaN as the split
// value, then the contents of `split_cat` are appended to split_categories_
// and the node's type and category segment are recorded.
//
//   split_cat  words appended verbatim to split_categories_.
//              NOTE(review): presumably the category bit field storage of this
//              split -- confirm against callers.
// All remaining arguments are forwarded unchanged to ExpandNode.
void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
                                common::Span<uint32_t> split_cat, bool default_left,
                                bst_float base_weight,
                                bst_float left_leaf_weight,
                                bst_float right_leaf_weight,
                                bst_float loss_change, float sum_hess,
                                float left_sum, float right_sum) {
  this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
                   default_left, base_weight,
                   left_leaf_weight, right_leaf_weight, loss_change, sum_hess,
                   left_sum, right_sum);

  // Grow the shared category storage and copy the new words onto its tail.
  size_t orig_size = split_categories_.size();
  this->split_categories_.resize(orig_size + split_cat.size());
  std::copy_n(split_cat.data(), split_cat.size(),
              split_categories_.begin() + orig_size);

  // ExpandNode marked the node as numerical; override that here.
  this->split_types_.at(nid) = FeatureType::kCategorical;

  // Remember which slice of split_categories_ belongs to this node.
  auto &segment = this->split_categories_segments_.at(nid);
  segment.beg = orig_size;
  segment.size = split_cat.size();
}
|
||||
|
||||
void RegTree::Load(dmlc::Stream* fi) {
|
||||
CHECK_EQ(fi->Read(¶m, sizeof(TreeParam)), sizeof(TreeParam));
|
||||
if (!DMLC_IO_NO_ENDIAN_SWAP) {
|
||||
@@ -751,11 +799,24 @@ void RegTree::LoadModel(Json const& in) {
|
||||
auto const& default_left = get<Array const>(in["default_left"]);
|
||||
CHECK_EQ(default_left.size(), n_nodes);
|
||||
|
||||
bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
|
||||
std::vector<Json> split_type;
|
||||
std::vector<Json> categories;
|
||||
if (has_cat) {
|
||||
split_type = get<Array const>(in["split_type"]);
|
||||
categories = get<Array const>(in["categories"]);
|
||||
}
|
||||
|
||||
|
||||
stats_.clear();
|
||||
nodes_.clear();
|
||||
|
||||
stats_.resize(n_nodes);
|
||||
nodes_.resize(n_nodes);
|
||||
split_types_.resize(n_nodes);
|
||||
split_categories_segments_.resize(n_nodes);
|
||||
|
||||
CHECK_EQ(n_nodes, split_categories_segments_.size());
|
||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||
auto& s = stats_[i];
|
||||
s.loss_chg = get<Number const>(loss_changes[i]);
|
||||
@@ -771,6 +832,31 @@ void RegTree::LoadModel(Json const& in) {
|
||||
float cond { get<Number const>(conds[i]) };
|
||||
bool dft_left { get<Boolean const>(default_left[i]) };
|
||||
n = Node{left, right, parent, ind, cond, dft_left};
|
||||
|
||||
if (has_cat) {
|
||||
split_types_[i] =
|
||||
static_cast<FeatureType>(get<Integer const>(split_type[i]));
|
||||
auto const& j_categories = get<Array const>(categories[i]);
|
||||
bst_cat_t max_cat { std::numeric_limits<bst_cat_t>::min() };
|
||||
for (auto const& j_cat : j_categories) {
|
||||
auto cat = common::AsCat(get<Integer const>(j_cat));
|
||||
max_cat = std::max(max_cat, cat);
|
||||
}
|
||||
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
|
||||
? 0
|
||||
: common::KCatBitField::ComputeStorageSize(max_cat);
|
||||
std::vector<uint32_t> cat_bits_storage(size);
|
||||
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
|
||||
for (auto const& j_cat : j_categories) {
|
||||
cat_bits.Set(common::AsCat(get<Integer const>(j_cat)));
|
||||
}
|
||||
auto begin = split_categories_.size();
|
||||
split_categories_.resize(begin + cat_bits_storage.size());
|
||||
std::copy(cat_bits_storage.begin(), cat_bits_storage.end(),
|
||||
split_categories_.begin() + begin);
|
||||
split_categories_segments_[i].beg = begin;
|
||||
split_categories_segments_[i].size = cat_bits_storage.size();
|
||||
}
|
||||
}
|
||||
|
||||
deleted_nodes_.clear();
|
||||
@@ -811,8 +897,11 @@ void RegTree::SaveModel(Json* p_out) const {
|
||||
std::vector<Json> indices(n_nodes);
|
||||
std::vector<Json> conds(n_nodes);
|
||||
std::vector<Json> default_left(n_nodes);
|
||||
std::vector<Json> split_type(n_nodes);
|
||||
|
||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||
std::vector<Json> categories(n_nodes);
|
||||
|
||||
for (bst_node_t i = 0; i < n_nodes; ++i) {
|
||||
auto const& s = stats_[i];
|
||||
loss_changes[i] = s.loss_chg;
|
||||
sum_hessian[i] = s.sum_hess;
|
||||
@@ -826,6 +915,24 @@ void RegTree::SaveModel(Json* p_out) const {
|
||||
indices[i] = static_cast<I>(n.SplitIndex());
|
||||
conds[i] = n.SplitCond();
|
||||
default_left[i] = n.DefaultLeft();
|
||||
|
||||
std::vector<Json> categories_temp;
|
||||
// This condition is only for being compatible with older version of XGBoost model
|
||||
// that doesn't have categorical data support.
|
||||
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
|
||||
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
|
||||
split_type[i] = static_cast<I>(this->NodeSplitType(i));
|
||||
auto beg = this->GetSplitCategoriesPtr().at(i).beg;
|
||||
auto size = this->GetSplitCategoriesPtr().at(i).size;
|
||||
auto node_categories = this->GetSplitCategories().subspan(beg, size);
|
||||
common::KCatBitField const cat_bits(node_categories);
|
||||
for (size_t i = 0; i < cat_bits.Size(); ++i) {
|
||||
if (cat_bits.Check(i)) {
|
||||
categories_temp.emplace_back(static_cast<Integer::Int>(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
categories[i] = Array(categories_temp);
|
||||
}
|
||||
|
||||
out["loss_changes"] = std::move(loss_changes);
|
||||
@@ -839,6 +946,12 @@ void RegTree::SaveModel(Json* p_out) const {
|
||||
out["split_indices"] = std::move(indices);
|
||||
out["split_conditions"] = std::move(conds);
|
||||
out["default_left"] = std::move(default_left);
|
||||
|
||||
out["categories"] = categories;
|
||||
|
||||
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
|
||||
out["split_type"] = std::move(split_type);
|
||||
}
|
||||
}
|
||||
|
||||
void RegTree::FillNodeMeanValues() {
|
||||
|
||||
Reference in New Issue
Block a user