Update JSON parser demo with categorical feature. (#8401)

- Parse categorical features in the Python example.
- Add tests.
- Update document.
Jiaming Yuan 2022-10-28 20:57:43 +08:00 committed by GitHub
parent cfd2a9f872
commit a408c34558
7 changed files with 318 additions and 133 deletions

View File

@@ -1,174 +1,281 @@
"""Demonstration for parsing JSON/UBJSON tree model file generated by XGBoost."""
import argparse
import json
from dataclasses import dataclass
from enum import IntEnum, unique
from typing import Any, Dict, List, Sequence, Union
import numpy as np
try:
import ubjson
except ImportError:
ubjson = None
ParamT = Dict[str, str]
def to_integers(data: Union[bytes, List[int]]) -> List[int]:
"""Convert a sequence of bytes to a list of Python integer"""
return [v for v in data]
@unique
class SplitType(IntEnum):
numerical = 0
categorical = 1
@dataclass
class Node:
# properties
left: int
right: int
parent: int
split_idx: int
split_cond: float
default_left: bool
split_type: SplitType
categories: List[int]
# statistics
base_weight: float
loss_chg: float
sum_hess: float
class Tree:
"""A tree built by XGBoost."""
def __init__(self, tree_id: int, nodes: Sequence[Node]) -> None:
self.tree_id = tree_id
self.nodes = nodes
def loss_change(self, node_id: int) -> float:
"""Loss gain of a node."""
return self.nodes[node_id].loss_chg
def sum_hessian(self, node_id: int) -> float:
"""Sum Hessian of a node."""
return self.nodes[node_id].sum_hess
def base_weight(self, node_id: int) -> float:
"""Base weight of a node."""
return self.nodes[node_id].base_weight
def split_index(self, node_id: int) -> int:
"""Split feature index of node."""
return self.nodes[node_id].split_idx
def split_condition(self, node_id: int) -> float:
"""Split value of a node."""
return self.nodes[node_id].split_cond
def split_categories(self, node_id: int) -> List[int]:
"""Categories in a node."""
return self.nodes[node_id].categories
def is_categorical(self, node_id: int) -> bool:
"""Whether a node has categorical split."""
return self.nodes[node_id].split_type == SplitType.categorical
def is_numerical(self, node_id: int) -> bool:
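"""Whether a node has numerical split."""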
return not self.is_categorical(node_id)
def parent(self, node_id: int) -> int:
"""Parent ID of a node."""
return self.nodes[node_id].parent
def left_child(self, node_id: int) -> int:
"""Left child ID of a node."""
return self.nodes[node_id].left
def right_child(self, node_id: int) -> int:
"""Right child ID of a node."""
return self.nodes[node_id].right
def is_leaf(self, node_id: int) -> bool:
"""Whether a node is leaf."""
return self.nodes[node_id].left == -1
def is_deleted(self, node_id: int) -> bool:
"""Whether a node is deleted."""
return self.split_index(node_id) == np.iinfo(np.uint32).max
def __str__(self) -> str:
stack = [0]
nodes = []
while stack:
node: Dict[str, Union[float, int, List[int]]] = {}
nid = stack.pop()
node["node id"] = nid
node["gain"] = self.loss_change(nid)
node["cover"] = self.sum_hessian(nid)
nodes.append(node)
if not self.is_leaf(nid) and not self.is_deleted(nid):
left = self.left_child(nid)
right = self.right_child(nid)
stack.append(left)
stack.append(right)
categories = self.split_categories(nid)
if categories:
assert self.is_categorical(nid)
node["categories"] = categories
else:
assert self.is_numerical(nid)
node["condition"] = self.split_condition(nid)
if self.is_leaf(nid):
node["weight"] = self.split_condition(nid)
string = "\n".join(map(lambda x: " " + str(x), nodes))
return string
class Model:
"""Gradient boosted tree model."""
def __init__(self, model: dict) -> None:
"""Construct the Model from a JSON object.
Parameters
----------
model : A dictionary loaded by json representing an XGBoost boosted tree model.
"""
# Basic properties of a model
self.learner_model_shape: ParamT = model["learner"]["learner_model_param"]
self.num_output_group = int(self.learner_model_shape["num_class"])
self.num_feature = int(self.learner_model_shape["num_feature"])
self.base_score = float(self.learner_model_shape["base_score"])
# A field encoding which output group a tree belongs to
self.tree_info = model["learner"]["gradient_booster"]["model"]["tree_info"]
model_shape: ParamT = model["learner"]["gradient_booster"]["model"][
"gbtree_model_param"
]
# JSON representation of trees
j_trees = model["learner"]["gradient_booster"]["model"]["trees"]
# Load the trees
self.num_trees = int(model_shape["num_trees"])
self.leaf_size = int(model_shape["size_leaf_vector"])
# XGBoost doesn't support vector leaf yet
assert self.leaf_size == 0, str(self.leaf_size)
trees: List[Tree] = []
for i in range(self.num_trees):
tree: Dict[str, Any] = j_trees[i]
tree_id = int(tree["id"])
assert tree_id == i, (tree_id, i)
# - properties
left_children: List[int] = tree["left_children"]
right_children: List[int] = tree["right_children"]
parents: List[int] = tree["parents"]
split_conditions: List[float] = tree["split_conditions"]
split_indices: List[int] = tree["split_indices"]
# when ubjson is used, this is a byte array with each element as uint8
default_left = to_integers(tree["default_left"])
# - categorical features
# when ubjson is used, this is a byte array with each element as uint8
split_types = to_integers(tree["split_type"])
# categories for each node are stored in a CSR-style layout, with the segments
# as the begin pointers and `categories` as the values.
cat_segments: List[int] = tree["categories_segments"]
cat_sizes: List[int] = tree["categories_sizes"]
# node index for categorical nodes
cat_nodes: List[int] = tree["categories_nodes"]
assert len(cat_segments) == len(cat_sizes) == len(cat_nodes)
cats = tree["categories"]
assert len(left_children) == len(split_types)
# The storage for categories is only defined for categorical nodes to
# prevent unnecessary overhead for numerical splits. As a result, we track
# the categorical nodes that have been processed using a counter.
cat_cnt = 0
if cat_nodes:
last_cat_node = cat_nodes[cat_cnt]
else:
last_cat_node = -1
node_categories: List[List[int]] = []
for node_id in range(len(left_children)):
if node_id == last_cat_node:
beg = cat_segments[cat_cnt]
size = cat_sizes[cat_cnt]
end = beg + size
node_cats = cats[beg:end]
# categories are unique for each node
assert len(set(node_cats)) == len(node_cats)
cat_cnt += 1
if cat_cnt == len(cat_nodes):
last_cat_node = -1 # continue to process the rest of the nodes
else:
last_cat_node = cat_nodes[cat_cnt]
assert node_cats
node_categories.append(node_cats)
else:
# append an empty category list; the node is either numerical or a leaf.
node_categories.append([])
# - stats
base_weights: List[float] = tree["base_weights"]
loss_changes: List[float] = tree["loss_changes"]
sum_hessian: List[float] = tree["sum_hessian"]
# Construct a list of nodes that have complete information
nodes: List[Node] = [
Node(
left_children[node_id],
right_children[node_id],
parents[node_id],
split_indices[node_id],
split_conditions[node_id],
default_left[node_id] == 1, # to boolean
SplitType(split_types[node_id]),
node_categories[node_id],
base_weights[node_id],
loss_changes[node_id],
sum_hessian[node_id],
)
for node_id in range(len(left_children))
]
pytree = Tree(tree_id, nodes)
trees.append(pytree)
self.trees = trees
def print_model(self) -> None:
for i, tree in enumerate(self.trees):
print("\ntree_id:", i)
print(tree)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Demonstration for loading XGBoost JSON/UBJSON model."
)
parser.add_argument(
"--model", type=str, required=True, help="Path to .json/.ubj model file."
)
args = parser.parse_args()
if args.model.endswith("json"):
# use json format
with open(args.model, "r") as fd:
model = json.load(fd)
elif args.model.endswith("ubj"):
if ubjson is None:
raise ImportError("ubjson is not installed.")
# use ubjson format
with open(args.model, "rb") as bfd:
model = ubjson.load(bfd)
else:
raise ValueError(
"Unexpected file extension. Supported file extension are json and ubj."
)
model = Model(model)
model.print_model()
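A minimal end-to-end sketch of the demo (values and paths are illustrative; it assumes the classes above are importable, e.g. from this file saved as json_parser.py):

import json
import os
import tempfile

import numpy as np
import pandas as pd
import xgboost

from json_parser import Model  # hypothetical import of the demo above

# Train a tiny model with one categorical and one numerical feature.
rng = np.random.RandomState(2022)
X = pd.DataFrame(
    {
        "cat": pd.Categorical(rng.randint(0, 4, size=64)),
        "num": rng.randn(64),
    }
)
y = rng.randn(64)
reg = xgboost.XGBRegressor(n_estimators=2, tree_method="hist", enable_categorical=True)
reg.fit(X, y)

# Save as JSON, then parse and print the trees with the demo's Model class.
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "reg.json")
    reg.save_model(path)
    with open(path, "r") as fd:
        Model(json.load(fd)).print_model()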

View File

@@ -245,11 +245,11 @@ JSON Schema
Another important feature of JSON format is a documented `schema
<https://json-schema.org/>`__, based on which one can easily reuse the output model from
XGBoost. Here is the JSON schema for the output model (not serialization, which will not
be stable as noted above). For an example of parsing an XGBoost tree model, see
``/demo/json-model``. Please note the "weight_drop" field used in the "dart" booster:
XGBoost does not scale tree leaves directly; instead it saves the weights as a separate
array.
.. include:: ../model.schema
:code: json
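As a rough orientation, a heavily elided sketch of the model layout follows (field names taken from the parser demo above; values are illustrative, and this is not the full schema):

# Heavily elided sketch of the JSON model layout (illustrative values):
model = {
    "learner": {
        "learner_model_param": {
            "num_class": "0", "num_feature": "2", "base_score": "5E-1",
        },
        "gradient_booster": {
            "model": {
                "gbtree_model_param": {"num_trees": "1", "size_leaf_vector": "0"},
                "tree_info": [0],  # output group of each tree
                "trees": [
                    {
                        "id": 0,
                        "left_children": [1, -1, -1],
                        "right_children": [2, -1, -1],
                        # categorical splits are stored CSR-style:
                        "categories_nodes": [0],
                        "categories_segments": [0],
                        "categories_sizes": [2],
                        "categories": [0, 3],
                    }
                ],
            }
        },
    }
}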

View File

@@ -439,7 +439,7 @@ class RegTree : public Model {
* \param left_sum The sum hess of left leaf.
* \param right_sum The sum hess of right leaf.
*/
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
common::Span<const uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,

View File

@@ -3,6 +3,7 @@ change without notice.
"""
# pylint: disable=invalid-name,missing-function-docstring,import-error
import copy
import gc
import importlib.util
import multiprocessing
@@ -477,6 +478,7 @@ def get_mq2008(
)
# pylint: disable=too-many-arguments,too-many-locals
@memory.cache
def make_categorical(
n_samples: int,
@@ -484,8 +486,27 @@
n_categories: int,
onehot: bool,
sparsity: float = 0.0,
cat_ratio: float = 1.0,
) -> Tuple[ArrayLike, np.ndarray]:
"""Generate categorical features for test.
Parameters
----------
n_categories:
Number of categories for categorical features.
onehot:
Should we apply one-hot encoding to the data?
sparsity:
The ratio of the amount of missing values over the number of all entries.
cat_ratio:
The ratio of features that are categorical.
Returns
-------
X, y
"""
import pandas as pd
from pandas.api.types import is_categorical_dtype
rng = np.random.RandomState(1994)
@@ -501,9 +522,10 @@
label += df.iloc[:, i]
label += 1
categories = np.arange(0, n_categories)
for col in df.columns:
if rng.binomial(1, cat_ratio, size=1)[0] == 1:
df[col] = df[col].astype("category")
df[col] = df[col].cat.set_categories(categories)
if sparsity > 0.0:
@@ -512,9 +534,14 @@
low=0, high=n_samples - 1, size=int(n_samples * sparsity)
)
df.iloc[index, i] = np.NaN
if is_categorical_dtype(df.dtypes[i]):
assert n_categories == np.unique(df.dtypes[i].categories).size
if onehot:
df = pd.get_dummies(df)
columns = list(df.columns)
rng.shuffle(columns)
df = df[columns]
return df, label
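A quick usage sketch for the updated helper (mirroring the call added to the demo test below; values are illustrative):

from xgboost import testing as tm

# Half the features categorical, half numerical, with 50% of entries missing.
X, y = tm.make_categorical(
    n_samples=128,
    n_features=4,
    n_categories=6,
    onehot=False,
    sparsity=0.5,
    cat_ratio=0.5,
)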

View File

@@ -807,7 +807,7 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
this->split_types_.at(nid) = FeatureType::kNumerical;
}
void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
common::Span<const uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
@@ -935,12 +935,15 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
if (!categories_nodes.empty()) {
last_cat_node = GetElem<Integer>(categories_nodes, cnt);
}
// `categories_segments` is only available for categorical nodes to prevent overhead for
// numerical nodes. As a result, we need to track the categorical nodes we have processed
// so far.
for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) {
if (nidx == last_cat_node) {
auto j_begin = GetElem<Integer>(categories_segments, cnt);
auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin;
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
CHECK_GT(j_end - j_begin, 0) << nidx;
for (auto j = j_begin; j < j_end; ++j) {
auto const& category = GetElem<Integer>(categories, j);
@@ -1059,6 +1062,8 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
if (has_cat) {
split_type = get<U8ArrayT const>(in["split_type"]);
}
// Initialization
stats = std::remove_reference_t<decltype(stats)>(n_nodes);
nodes = std::remove_reference_t<decltype(nodes)>(n_nodes);
split_types = std::remove_reference_t<decltype(split_types)>(n_nodes);
@@ -1068,6 +1073,7 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value, "");
CHECK_EQ(n_nodes, split_categories_segments.size());
// Set node
for (int32_t i = 0; i < n_nodes; ++i) {
auto& s = stats[i];
s.loss_chg = GetElem<Number>(loss_changes, i);

View File

@@ -136,6 +136,7 @@ if __name__ == "__main__":
"tests/test_distributed/test_with_spark/",
"tests/test_distributed/test_gpu_with_spark/",
# demo
"demo/json-model/json_parser.py",
"demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/categorical.py",
"demo/guide-python/spark_estimator_examples.py",
@@ -147,9 +148,13 @@ if __name__ == "__main__":
if not all(
run_mypy(path)
for path in [
# core
"python-package/xgboost/",
# demo
"demo/json-model/json_parser.py",
"demo/guide-python/external_memory.py",
"demo/guide-python/cat_in_the_dat.py",
# tests
"tests/python/test_data_iterator.py",
"tests/python-gpu/test_gpu_data_iterator.py",
"tests/ci_build/lint_python.py",

View File

@@ -1,9 +1,11 @@
import os
import subprocess
import sys
import tempfile
import pytest
import xgboost
from xgboost import testing as tm
pytestmark = tm.timeout(30)
@@ -138,10 +140,48 @@ def test_multioutput_reg() -> None:
subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_ubjson())
def test_json_model() -> None:
script = os.path.join(DEMO_DIR, "json-model", "json_parser.py")
def run_test(reg: xgboost.XGBRegressor) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "reg.json")
reg.save_model(path)
cmd = ["python", script, f"--model={path}"]
subprocess.check_call(cmd)
path = os.path.join(tmpdir, "reg.ubj")
reg.save_model(path)
cmd = ["python", script, f"--model={path}"]
subprocess.check_call(cmd)
# numerical
X, y = tm.make_sparse_regression(100, 10, 0.5, False)
reg = xgboost.XGBRegressor(n_estimators=2, tree_method="hist")
reg.fit(X, y)
run_test(reg)
# categorical
X, y = tm.make_categorical(
n_samples=1000,
n_features=10,
n_categories=6,
onehot=False,
sparsity=0.5,
cat_ratio=0.5,
)
reg = xgboost.XGBRegressor(
n_estimators=2, tree_method="hist", enable_categorical=True
)
reg.fit(X, y)
run_test(reg)
# - gpu_acceleration is not tested because the covertype dataset is too huge.
# - gamma regression is not tested as it requires running an R script first.
# - aft viz is not tested because plotting is not controlled.
# - aft tuning is not tested due to an extra dependency.
def test_cli_regression_demo():