Update JSON parser demo with categorical feature. (#8401)

- Parse categorical features in the Python example.
- Add tests.
- Update document.
This commit is contained in:
Jiaming Yuan
2022-10-28 20:57:43 +08:00
committed by GitHub
parent cfd2a9f872
commit a408c34558
7 changed files with 318 additions and 133 deletions

View File

@@ -807,7 +807,7 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
this->split_types_.at(nid) = FeatureType::kNumerical;
}
void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
common::Span<const uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
@@ -935,12 +935,15 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
if (!categories_nodes.empty()) {
last_cat_node = GetElem<Integer>(categories_nodes, cnt);
}
// `categories_segments' is only available for categorical nodes to prevent overhead for
// numerical node. As a result, we need to track the categorical nodes we have processed
// so far.
for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) {
if (nidx == last_cat_node) {
auto j_begin = GetElem<Integer>(categories_segments, cnt);
auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin;
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
CHECK_NE(j_end - j_begin, 0) << nidx;
CHECK_GT(j_end - j_begin, 0) << nidx;
for (auto j = j_begin; j < j_end; ++j) {
auto const& category = GetElem<Integer>(categories, j);
@@ -1059,6 +1062,8 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
if (has_cat) {
split_type = get<U8ArrayT const>(in["split_type"]);
}
// Initialization
stats = std::remove_reference_t<decltype(stats)>(n_nodes);
nodes = std::remove_reference_t<decltype(nodes)>(n_nodes);
split_types = std::remove_reference_t<decltype(split_types)>(n_nodes);
@@ -1068,6 +1073,7 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value, "");
CHECK_EQ(n_nodes, split_categories_segments.size());
// Set node
for (int32_t i = 0; i < n_nodes; ++i) {
auto& s = stats[i];
s.loss_chg = GetElem<Number>(loss_changes, i);