Update JSON parser demo with categorical feature. (#8401)
- Parse categorical features in the Python example. - Add tests. - Update document.
This commit is contained in:
@@ -807,7 +807,7 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
|
||||
this->split_types_.at(nid) = FeatureType::kNumerical;
|
||||
}
|
||||
|
||||
void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
|
||||
void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
|
||||
common::Span<const uint32_t> split_cat, bool default_left,
|
||||
bst_float base_weight, bst_float left_leaf_weight,
|
||||
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
|
||||
@@ -935,12 +935,15 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
|
||||
if (!categories_nodes.empty()) {
|
||||
last_cat_node = GetElem<Integer>(categories_nodes, cnt);
|
||||
}
|
||||
// `categories_segments' is only available for categorical nodes to prevent overhead for
|
||||
// numerical node. As a result, we need to track the categorical nodes we have processed
|
||||
// so far.
|
||||
for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) {
|
||||
if (nidx == last_cat_node) {
|
||||
auto j_begin = GetElem<Integer>(categories_segments, cnt);
|
||||
auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin;
|
||||
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
|
||||
CHECK_NE(j_end - j_begin, 0) << nidx;
|
||||
CHECK_GT(j_end - j_begin, 0) << nidx;
|
||||
|
||||
for (auto j = j_begin; j < j_end; ++j) {
|
||||
auto const& category = GetElem<Integer>(categories, j);
|
||||
@@ -1059,6 +1062,8 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
|
||||
if (has_cat) {
|
||||
split_type = get<U8ArrayT const>(in["split_type"]);
|
||||
}
|
||||
|
||||
// Initialization
|
||||
stats = std::remove_reference_t<decltype(stats)>(n_nodes);
|
||||
nodes = std::remove_reference_t<decltype(nodes)>(n_nodes);
|
||||
split_types = std::remove_reference_t<decltype(split_types)>(n_nodes);
|
||||
@@ -1068,6 +1073,7 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
|
||||
static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value, "");
|
||||
CHECK_EQ(n_nodes, split_categories_segments.size());
|
||||
|
||||
// Set node
|
||||
for (int32_t i = 0; i < n_nodes; ++i) {
|
||||
auto& s = stats[i];
|
||||
s.loss_chg = GetElem<Number>(loss_changes, i);
|
||||
|
||||
Reference in New Issue
Block a user