Update JSON parser demo with categorical feature. (#8401)

- Parse categorical features in the Python example.
- Add tests.
- Update document.

This commit is contained in:
parent cfd2a9f872
commit a408c34558
@@ -1,174 +1,281 @@
-'''Demonstration for parsing JSON tree model file generated by XGBoost. The
-support is experimental, output schema is subject to change in the future.
-'''
-import json
+"""Demonstration for parsing JSON/UBJSON tree model file generated by XGBoost.
+"""
 import argparse
+import json
+from dataclasses import dataclass
+from enum import IntEnum, unique
+from typing import Any, Dict, List, Sequence, Union
+
+import numpy as np
+
+try:
+    import ubjson
+except ImportError:
+    ubjson = None
+
+
+ParamT = Dict[str, str]
+
+
+def to_integers(data: Union[bytes, List[int]]) -> List[int]:
+    """Convert a sequence of bytes to a list of Python integers."""
+    return [v for v in data]
+
+
+@unique
+class SplitType(IntEnum):
+    numerical = 0
+    categorical = 1
+
+
+@dataclass
+class Node:
+    # properties
+    left: int
+    right: int
+    parent: int
+    split_idx: int
+    split_cond: float
+    default_left: bool
+    split_type: SplitType
+    categories: List[int]
+    # statistic
+    base_weight: float
+    loss_chg: float
+    sum_hess: float
 
 
 class Tree:
-    '''A tree built by XGBoost.'''
-    # Index into node array
-    _left = 0
-    _right = 1
-    _parent = 2
-    _ind = 3
-    _cond = 4
-    _default_left = 5
-    # Index into stat array
-    _loss_chg = 0
-    _sum_hess = 1
-    _base_weight = 2
+    """A tree built by XGBoost."""
 
-    def __init__(self, tree_id: int, nodes, stats):
+    def __init__(self, tree_id: int, nodes: Sequence[Node]) -> None:
         self.tree_id = tree_id
         self.nodes = nodes
-        self.stats = stats
 
-    def loss_change(self, node_id: int):
-        '''Loss gain of a node.'''
-        return self.stats[node_id][self._loss_chg]
+    def loss_change(self, node_id: int) -> float:
+        """Loss gain of a node."""
+        return self.nodes[node_id].loss_chg
 
-    def sum_hessian(self, node_id: int):
-        '''Sum Hessian of a node.'''
-        return self.stats[node_id][self._sum_hess]
+    def sum_hessian(self, node_id: int) -> float:
+        """Sum Hessian of a node."""
+        return self.nodes[node_id].sum_hess
 
-    def base_weight(self, node_id: int):
-        '''Base weight of a node.'''
-        return self.stats[node_id][self._base_weight]
+    def base_weight(self, node_id: int) -> float:
+        """Base weight of a node."""
+        return self.nodes[node_id].base_weight
 
-    def split_index(self, node_id: int):
-        '''Split feature index of node.'''
-        return self.nodes[node_id][self._ind]
+    def split_index(self, node_id: int) -> int:
+        """Split feature index of node."""
+        return self.nodes[node_id].split_idx
 
-    def split_condition(self, node_id: int):
-        '''Split value of a node.'''
-        return self.nodes[node_id][self._cond]
+    def split_condition(self, node_id: int) -> float:
+        """Split value of a node."""
+        return self.nodes[node_id].split_cond
 
-    def parent(self, node_id: int):
-        '''Parent ID of a node.'''
-        return self.nodes[node_id][self._parent]
+    def split_categories(self, node_id: int) -> List[int]:
+        """Categories in a node."""
+        return self.nodes[node_id].categories
 
-    def left_child(self, node_id: int):
-        '''Left child ID of a node.'''
-        return self.nodes[node_id][self._left]
+    def is_categorical(self, node_id: int) -> bool:
+        """Whether a node has categorical split."""
+        return self.nodes[node_id].split_type == SplitType.categorical
 
-    def right_child(self, node_id: int):
-        '''Right child ID of a node.'''
-        return self.nodes[node_id][self._right]
+    def is_numerical(self, node_id: int) -> bool:
+        return not self.is_categorical(node_id)
 
-    def is_leaf(self, node_id: int):
-        '''Whether a node is leaf.'''
-        return self.nodes[node_id][self._left] == -1
+    def parent(self, node_id: int) -> int:
+        """Parent ID of a node."""
+        return self.nodes[node_id].parent
 
-    def is_deleted(self, node_id: int):
-        '''Whether a node is deleted.'''
-        # std::numeric_limits<uint32_t>::max()
-        return self.nodes[node_id][self._ind] == 4294967295
+    def left_child(self, node_id: int) -> int:
+        """Left child ID of a node."""
+        return self.nodes[node_id].left
 
-    def __str__(self):
-        stacks = [0]
+    def right_child(self, node_id: int) -> int:
+        """Right child ID of a node."""
+        return self.nodes[node_id].right
+
+    def is_leaf(self, node_id: int) -> bool:
+        """Whether a node is leaf."""
+        return self.nodes[node_id].left == -1
+
+    def is_deleted(self, node_id: int) -> bool:
+        """Whether a node is deleted."""
+        return self.split_index(node_id) == np.iinfo(np.uint32).max
+
+    def __str__(self) -> str:
+        stack = [0]
         nodes = []
-        while stacks:
-            node = {}
-            nid = stacks.pop()
+        while stack:
+            node: Dict[str, Union[float, int, List[int]]] = {}
+            nid = stack.pop()
 
-            node['node id'] = nid
-            node['gain'] = self.loss_change(nid)
-            node['cover'] = self.sum_hessian(nid)
+            node["node id"] = nid
+            node["gain"] = self.loss_change(nid)
+            node["cover"] = self.sum_hessian(nid)
             nodes.append(node)
 
             if not self.is_leaf(nid) and not self.is_deleted(nid):
                 left = self.left_child(nid)
                 right = self.right_child(nid)
-                stacks.append(left)
-                stacks.append(right)
+                stack.append(left)
+                stack.append(right)
+
+            categories = self.split_categories(nid)
+            if categories:
+                assert self.is_categorical(nid)
+                node["categories"] = categories
+            else:
+                assert self.is_numerical(nid)
+                node["condition"] = self.split_condition(nid)
+            if self.is_leaf(nid):
+                node["weight"] = self.split_condition(nid)
 
-        string = '\n'.join(map(lambda x: ' ' + str(x), nodes))
+        string = "\n".join(map(lambda x: " " + str(x), nodes))
         return string
 
 
 class Model:
-    '''Gradient boosted tree model.'''
-    def __init__(self, model: dict):
-        '''Construct the Model from JSON object.
+    """Gradient boosted tree model."""
+
+    def __init__(self, model: dict) -> None:
+        """Construct the Model from a JSON object.
 
         parameters
         ----------
-        m: A dictionary loaded by json
-        '''
-        # Basic property of a model
-        self.learner_model_shape = model['learner']['learner_model_param']
-        self.num_output_group = int(self.learner_model_shape['num_class'])
-        self.num_feature = int(self.learner_model_shape['num_feature'])
-        self.base_score = float(self.learner_model_shape['base_score'])
+        model : A dictionary loaded by json representing an XGBoost boosted tree model.
+        """
+        # Basic properties of a model
+        self.learner_model_shape: ParamT = model["learner"]["learner_model_param"]
+        self.num_output_group = int(self.learner_model_shape["num_class"])
+        self.num_feature = int(self.learner_model_shape["num_feature"])
+        self.base_score = float(self.learner_model_shape["base_score"])
         # A field encoding which output group a tree belongs
-        self.tree_info = model['learner']['gradient_booster']['model'][
-            'tree_info']
+        self.tree_info = model["learner"]["gradient_booster"]["model"]["tree_info"]
 
-        model_shape = model['learner']['gradient_booster']['model'][
-            'gbtree_model_param']
+        model_shape: ParamT = model["learner"]["gradient_booster"]["model"][
+            "gbtree_model_param"
+        ]
 
         # JSON representation of trees
-        j_trees = model['learner']['gradient_booster']['model']['trees']
+        j_trees = model["learner"]["gradient_booster"]["model"]["trees"]
 
         # Load the trees
-        self.num_trees = int(model_shape['num_trees'])
-        self.leaf_size = int(model_shape['size_leaf_vector'])
+        self.num_trees = int(model_shape["num_trees"])
+        self.leaf_size = int(model_shape["size_leaf_vector"])
         # Right now XGBoost doesn't support vector leaf yet
         assert self.leaf_size == 0, str(self.leaf_size)
 
-        trees = []
+        trees: List[Tree] = []
         for i in range(self.num_trees):
-            tree = j_trees[i]
-            tree_id = int(tree['id'])
+            tree: Dict[str, Any] = j_trees[i]
+            tree_id = int(tree["id"])
             assert tree_id == i, (tree_id, i)
-            # properties
-            left_children = tree['left_children']
-            right_children = tree['right_children']
-            parents = tree['parents']
-            split_conditions = tree['split_conditions']
-            split_indices = tree['split_indices']
-            default_left = tree['default_left']
-            # stats
-            base_weights = tree['base_weights']
-            loss_changes = tree['loss_changes']
-            sum_hessian = tree['sum_hessian']
+            # - properties
+            left_children: List[int] = tree["left_children"]
+            right_children: List[int] = tree["right_children"]
+            parents: List[int] = tree["parents"]
+            split_conditions: List[float] = tree["split_conditions"]
+            split_indices: List[int] = tree["split_indices"]
+            # when ubjson is used, this is a byte array with each element as uint8
+            default_left = to_integers(tree["default_left"])
 
-            stats = []
-            nodes = []
-            # We resemble the structure used inside XGBoost, which is similar
-            # to adjacency list.
+            # - categorical features
+            # when ubjson is used, this is a byte array with each element as uint8
+            split_types = to_integers(tree["split_type"])
+            # categories for each node are stored in a CSR style storage with
+            # segment as the begin ptr and the `categories' as values.
+            cat_segments: List[int] = tree["categories_segments"]
+            cat_sizes: List[int] = tree["categories_sizes"]
+            # node index for categorical nodes
+            cat_nodes: List[int] = tree["categories_nodes"]
+            assert len(cat_segments) == len(cat_sizes) == len(cat_nodes)
+            cats = tree["categories"]
+            assert len(left_children) == len(split_types)
+
+            # The storage for categories is only defined for categorical nodes
+            # to prevent unnecessary overhead for numerical splits; we track
+            # the categorical nodes that are processed using a counter.
+            cat_cnt = 0
+            if cat_nodes:
+                last_cat_node = cat_nodes[cat_cnt]
+            else:
+                last_cat_node = -1
+            node_categories: List[List[int]] = []
             for node_id in range(len(left_children)):
-                nodes.append([
-                    left_children[node_id], right_children[node_id],
-                    parents[node_id], split_indices[node_id],
-                    split_conditions[node_id], default_left[node_id]
-                ])
-                stats.append([
-                    loss_changes[node_id], sum_hessian[node_id],
-                    base_weights[node_id]
-                ])
+                if node_id == last_cat_node:
+                    beg = cat_segments[cat_cnt]
+                    size = cat_sizes[cat_cnt]
+                    end = beg + size
+                    node_cats = cats[beg:end]
+                    # categories are unique for each node
+                    assert len(set(node_cats)) == len(node_cats)
+                    cat_cnt += 1
+                    if cat_cnt == len(cat_nodes):
+                        last_cat_node = -1  # continue to process the rest of the nodes
+                    else:
+                        last_cat_node = cat_nodes[cat_cnt]
+                    assert node_cats
+                    node_categories.append(node_cats)
+                else:
+                    # append an empty node, it's either a numerical node or a leaf.
+                    node_categories.append([])
 
-            tree = Tree(tree_id, nodes, stats)
-            trees.append(tree)
+            # - stats
+            base_weights: List[float] = tree["base_weights"]
+            loss_changes: List[float] = tree["loss_changes"]
+            sum_hessian: List[float] = tree["sum_hessian"]
+
+            # Construct a list of nodes that have complete information
+            nodes: List[Node] = [
+                Node(
+                    left_children[node_id],
+                    right_children[node_id],
+                    parents[node_id],
+                    split_indices[node_id],
+                    split_conditions[node_id],
+                    default_left[node_id] == 1,  # to boolean
+                    SplitType(split_types[node_id]),
+                    node_categories[node_id],
+                    base_weights[node_id],
+                    loss_changes[node_id],
+                    sum_hessian[node_id],
+                )
+                for node_id in range(len(left_children))
+            ]
+
+            pytree = Tree(tree_id, nodes)
+            trees.append(pytree)
 
         self.trees = trees
 
-    def print_model(self):
+    def print_model(self) -> None:
         for i, tree in enumerate(self.trees):
-            print('tree_id:', i)
+            print("\ntree_id:", i)
             print(tree)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description='Demonstration for loading and printing XGBoost model.')
-    parser.add_argument('--model',
-                        type=str,
-                        required=True,
-                        help='Path to JSON model file.')
+        description="Demonstration for loading XGBoost JSON/UBJSON model."
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="Path to .json/.ubj model file."
+    )
     args = parser.parse_args()
-    with open(args.model, 'r') as fd:
-        model = json.load(fd)
+    if args.model.endswith("json"):
+        # use json format
+        with open(args.model, "r") as fd:
+            model = json.load(fd)
+    elif args.model.endswith("ubj"):
+        if ubjson is None:
+            raise ImportError("ubjson is not installed.")
+        # use ubjson format
+        with open(args.model, "rb") as bfd:
+            model = ubjson.load(bfd)
+    else:
+        raise ValueError(
+            "Unexpected file extension. Supported file extensions are json and ubj."
+        )
     model = Model(model)
     model.print_model()
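End-to-end usage of the updated demo, as a minimal sketch (not part of the commit; assumes xgboost and pandas are installed, and that Model is imported from the demo script above; parameter values are illustrative):

    import json

    import pandas as pd
    import xgboost

    # train a tiny model with one categorical and one numerical feature
    X = pd.DataFrame({"c": pd.Categorical([0, 1, 2, 1] * 25), "f": range(100)})
    y = [v % 3 for v in range(100)]
    reg = xgboost.XGBRegressor(
        n_estimators=2, tree_method="hist", enable_categorical=True
    )
    reg.fit(X, y)
    reg.save_model("reg.json")

    with open("reg.json", "r") as fd:
        model = Model(json.load(fd))  # Model from demo/json-model/json_parser.py
    model.print_model()
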
@@ -245,11 +245,11 @@ JSON Schema
 
 Another important feature of JSON format is a documented `schema
 <https://json-schema.org/>`__, based on which one can easily reuse the output model from
-XGBoost. Here is the initial draft of JSON schema for the output model (not
-serialization, which will not be stable as noted above). It's subject to change due to
-the beta status. For an example of parsing XGBoost tree model, see ``/demo/json-model``.
-Please notice the "weight_drop" field used in "dart" booster. XGBoost does not scale tree
-leaf directly, instead it saves the weights as a separated array.
+XGBoost. Here is the JSON schema for the output model (not serialization, which will not
+be stable as noted above). For an example of parsing XGBoost tree model, see
+``/demo/json-model``. Please notice the "weight_drop" field used in "dart" booster.
+XGBoost does not scale tree leaf directly, instead it saves the weights as a separated
+array.
 
 .. include:: ../model.schema
    :code: json

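To make the "weight_drop" remark concrete, here is a hypothetical sketch of applying the weights when reading a "dart" model. The field paths are assumptions about the schema rather than content of this commit; consult model.schema for the authoritative layout:

    # `model` is a dict loaded from a saved dart booster, as in the demo above.
    booster = model["learner"]["gradient_booster"]  # assumed: {"name": "dart", ...}
    weight_drop = booster["weight_drop"]            # assumed: one factor per tree
    trees = booster["gbtree"]["model"]["trees"]     # assumed: dart embeds a gbtree
    for scale, tree in zip(weight_drop, trees):
        # Leaves keep their unscaled values in "split_conditions"; the effective
        # leaf value is the stored value times the tree's weight.
        scaled_leaves = [scale * v for v in tree["split_conditions"]]
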
@@ -439,7 +439,7 @@ class RegTree : public Model {
    * \param left_sum The sum hess of left leaf.
    * \param right_sum The sum hess of right leaf.
    */
-  void ExpandCategorical(bst_node_t nid, unsigned split_index,
+  void ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
                          common::Span<const uint32_t> split_cat, bool default_left,
                          bst_float base_weight, bst_float left_leaf_weight,
                          bst_float right_leaf_weight, bst_float loss_change, float sum_hess,

@@ -3,6 +3,7 @@ change without notice.
 
 """
 # pylint: disable=invalid-name,missing-function-docstring,import-error
+import copy
 import gc
 import importlib.util
 import multiprocessing
@@ -477,6 +478,7 @@ def get_mq2008(
     )
 
 
+# pylint: disable=too-many-arguments,too-many-locals
 @memory.cache
 def make_categorical(
     n_samples: int,
@@ -484,8 +486,27 @@ def make_categorical(
     n_categories: int,
     onehot: bool,
     sparsity: float = 0.0,
+    cat_ratio: float = 1.0,
 ) -> Tuple[ArrayLike, np.ndarray]:
+    """Generate categorical features for test.
+
+    Parameters
+    ----------
+    n_categories:
+        Number of categories for categorical features.
+    onehot:
+        Should we apply one-hot encoding to the data?
+    sparsity:
+        The ratio of the amount of missing values over the number of all entries.
+    cat_ratio:
+        The ratio of features that are categorical.
+
+    Returns
+    -------
+    X, y
+    """
     import pandas as pd
+    from pandas.api.types import is_categorical_dtype
 
     rng = np.random.RandomState(1994)
 
@@ -501,10 +522,11 @@ def make_categorical(
         label += df.iloc[:, i]
     label += 1
 
-    df = df.astype("category")
     categories = np.arange(0, n_categories)
     for col in df.columns:
-        df[col] = df[col].cat.set_categories(categories)
+        if rng.binomial(1, cat_ratio, size=1)[0] == 1:
+            df[col] = df[col].astype("category")
+            df[col] = df[col].cat.set_categories(categories)
 
     if sparsity > 0.0:
         for i in range(n_features):
@@ -512,9 +534,14 @@ def make_categorical(
                 low=0, high=n_samples - 1, size=int(n_samples * sparsity)
             )
             df.iloc[index, i] = np.NaN
-            assert n_categories == np.unique(df.dtypes[i].categories).size
+            if is_categorical_dtype(df.dtypes[i]):
+                assert n_categories == np.unique(df.dtypes[i].categories).size
 
     if onehot:
+        df = pd.get_dummies(df)
+        columns = list(df.columns)
+        rng.shuffle(columns)
+        df = df[columns]
         return pd.get_dummies(df), label
     return df, label

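The new cat_ratio argument makes it possible to mix categorical and numerical columns in one generated frame. An illustrative call (argument values are arbitrary; assumes xgboost from this commit is importable):

    from xgboost import testing as tm

    X, y = tm.make_categorical(
        n_samples=128,
        n_features=4,
        n_categories=6,
        onehot=False,
        sparsity=0.3,
        cat_ratio=0.5,  # each column becomes categorical with probability 0.5
    )
    print(X.dtypes)  # typically a mix of numerical and category dtypes
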
@@ -807,7 +807,7 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
   this->split_types_.at(nid) = FeatureType::kNumerical;
 }
 
-void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
+void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
                                 common::Span<const uint32_t> split_cat, bool default_left,
                                 bst_float base_weight, bst_float left_leaf_weight,
                                 bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
@@ -935,12 +935,15 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
   if (!categories_nodes.empty()) {
     last_cat_node = GetElem<Integer>(categories_nodes, cnt);
   }
+  // `categories_segments' is only available for categorical nodes to prevent overhead for
+  // numerical nodes. As a result, we need to track the categorical nodes we have processed
+  // so far.
   for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) {
     if (nidx == last_cat_node) {
       auto j_begin = GetElem<Integer>(categories_segments, cnt);
       auto j_end = GetElem<Integer>(categories_sizes, cnt) + j_begin;
       bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
-      CHECK_NE(j_end - j_begin, 0) << nidx;
+      CHECK_GT(j_end - j_begin, 0) << nidx;
 
       for (auto j = j_begin; j < j_end; ++j) {
         auto const& category = GetElem<Integer>(categories, j);
@@ -1059,6 +1062,8 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
   if (has_cat) {
     split_type = get<U8ArrayT const>(in["split_type"]);
   }
+
+  // Initialization
   stats = std::remove_reference_t<decltype(stats)>(n_nodes);
   nodes = std::remove_reference_t<decltype(nodes)>(n_nodes);
   split_types = std::remove_reference_t<decltype(split_types)>(n_nodes);
@@ -1068,6 +1073,7 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
   static_assert(std::is_floating_point<decltype(GetElem<Number>(loss_changes, 0))>::value, "");
   CHECK_EQ(n_nodes, split_categories_segments.size());
 
+  // Set node
   for (int32_t i = 0; i < n_nodes; ++i) {
     auto& s = stats[i];
     s.loss_chg = GetElem<Number>(loss_changes, i);

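The CSR-style storage referenced in both the Python demo and LoadCategoricalSplit can be illustrated standalone; a sketch with made-up values (nodes 3 and 7 carry categorical splits):

    categories_nodes = [3, 7]     # indices of categorical nodes
    categories_segments = [0, 2]  # begin pointer into `categories` per such node
    categories_sizes = [2, 3]     # number of categories per such node
    categories = [1, 4, 0, 2, 5]  # flattened category values for all such nodes

    node_cats = {}
    for cnt, nidx in enumerate(categories_nodes):
        beg = categories_segments[cnt]
        end = beg + categories_sizes[cnt]
        node_cats[nidx] = categories[beg:end]

    assert node_cats == {3: [1, 4], 7: [0, 2, 5]}
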
@@ -136,6 +136,7 @@ if __name__ == "__main__":
             "tests/test_distributed/test_with_spark/",
             "tests/test_distributed/test_gpu_with_spark/",
             # demo
+            "demo/json-model/json_parser.py",
             "demo/guide-python/cat_in_the_dat.py",
             "demo/guide-python/categorical.py",
             "demo/guide-python/spark_estimator_examples.py",
@@ -147,9 +148,13 @@ if __name__ == "__main__":
     if not all(
         run_mypy(path)
         for path in [
+            # core
             "python-package/xgboost/",
+            # demo
+            "demo/json-model/json_parser.py",
             "demo/guide-python/external_memory.py",
             "demo/guide-python/cat_in_the_dat.py",
+            # tests
             "tests/python/test_data_iterator.py",
             "tests/python-gpu/test_gpu_data_iterator.py",
             "tests/ci_build/lint_python.py",

@@ -1,9 +1,11 @@
 import os
 import subprocess
+import tempfile
 import sys
 
 import pytest
 
+import xgboost
 from xgboost import testing as tm
 
 pytestmark = tm.timeout(30)
@@ -138,10 +140,48 @@ def test_multioutput_reg() -> None:
     subprocess.check_call(cmd)
 
 
-# gpu_acceleration is not tested due to covertype dataset is being too huge.
-# gamma regression is not tested as it requires running a R script first.
-# aft viz is not tested due to ploting is not controled
-# aft tunning is not tested due to extra dependency.
+@pytest.mark.skipif(**tm.no_ubjson())
+def test_json_model() -> None:
+    script = os.path.join(DEMO_DIR, "json-model", "json_parser.py")
+
+    def run_test(reg: xgboost.XGBRegressor) -> None:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = os.path.join(tmpdir, "reg.json")
+            reg.save_model(path)
+            cmd = ["python", script, f"--model={path}"]
+            subprocess.check_call(cmd)
+
+            path = os.path.join(tmpdir, "reg.ubj")
+            reg.save_model(path)
+            cmd = ["python", script, f"--model={path}"]
+            subprocess.check_call(cmd)
+
+    # numerical
+    X, y = tm.make_sparse_regression(100, 10, 0.5, False)
+    reg = xgboost.XGBRegressor(n_estimators=2, tree_method="hist")
+    reg.fit(X, y)
+    run_test(reg)
+
+    # categorical
+    X, y = tm.make_categorical(
+        n_samples=1000,
+        n_features=10,
+        n_categories=6,
+        onehot=False,
+        sparsity=0.5,
+        cat_ratio=0.5,
+    )
+    reg = xgboost.XGBRegressor(
+        n_estimators=2, tree_method="hist", enable_categorical=True
+    )
+    reg.fit(X, y)
+    run_test(reg)
+
+
+# - gpu_acceleration is not tested because the covertype dataset is too huge.
+# - gamma regression is not tested as it requires running an R script first.
+# - aft viz is not tested because plotting is not controlled.
+# - aft tuning is not tested due to an extra dependency.
 
 
 def test_cli_regression_demo():
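One subtlety the new test exercises: when a model is saved as UBJSON, fields such as "default_left" and "split_type" decode to Python bytes (uint8 arrays), whereas JSON yields plain lists of ints. The demo's to_integers helper normalizes both; reproduced here as a runnable standalone snippet:

    from typing import List, Union

    def to_integers(data: Union[bytes, List[int]]) -> List[int]:
        """Convert a sequence of bytes to a list of Python integers."""
        return [v for v in data]

    assert to_integers(b"\x00\x01\x01") == [0, 1, 1]  # ubjson path: bytes
    assert to_integers([0, 1, 1]) == [0, 1, 1]        # json path: list of ints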