Handle categorical split in model histogram and dataframe. (#7065)

* Error on get_split_value_histogram when feature is categorical
* Add a category column to output dataframe
This commit is contained in:
Jiaming Yuan 2021-07-02 13:10:36 +08:00 committed by GitHub
parent 1cd20efe68
commit a5d222fcdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 96 additions and 23 deletions

View File

@ -2225,7 +2225,7 @@ class Booster(object):
results[feat] = float(score) results[feat] = float(score)
return results return results
def trees_to_dataframe(self, fmap=''): def trees_to_dataframe(self, fmap=''): # pylint: disable=too-many-statements
"""Parse a boosted tree model text dump into a pandas DataFrame structure. """Parse a boosted tree model text dump into a pandas DataFrame structure.
This feature is only defined when the decision tree model is chosen as base This feature is only defined when the decision tree model is chosen as base
@ -2251,6 +2251,7 @@ class Booster(object):
node_ids = [] node_ids = []
fids = [] fids = []
splits = [] splits = []
categories = []
y_directs = [] y_directs = []
n_directs = [] n_directs = []
missings = [] missings = []
@ -2275,6 +2276,7 @@ class Booster(object):
node_ids.append(int(re.findall(r'\b\d+\b', parse[0])[0])) node_ids.append(int(re.findall(r'\b\d+\b', parse[0])[0]))
fids.append('Leaf') fids.append('Leaf')
splits.append(float('NAN')) splits.append(float('NAN'))
categories.append(float('NAN'))
y_directs.append(float('NAN')) y_directs.append(float('NAN'))
n_directs.append(float('NAN')) n_directs.append(float('NAN'))
missings.append(float('NAN')) missings.append(float('NAN'))
@ -2284,14 +2286,26 @@ class Booster(object):
else: else:
# parse string # parse string
fid = arr[1].split(']') fid = arr[1].split(']')
if fid[0].find("<") != -1:
# numerical
parse = fid[0].split('<') parse = fid[0].split('<')
splits.append(float(parse[1]))
categories.append(None)
elif fid[0].find(":{") != -1:
# categorical
parse = fid[0].split(":")
cats = parse[1][1:-1] # strip the {}
cats = cats.split(",")
splits.append(float("NAN"))
categories.append(cats if cats else None)
else:
raise ValueError("Failed to parse model text dump.")
stats = re.split('=|,', fid[1]) stats = re.split('=|,', fid[1])
# append to lists # append to lists
tree_ids.append(i) tree_ids.append(i)
node_ids.append(int(re.findall(r'\b\d+\b', arr[0])[0])) node_ids.append(int(re.findall(r'\b\d+\b', arr[0])[0]))
fids.append(parse[0]) fids.append(parse[0])
splits.append(float(parse[1]))
str_i = str(i) str_i = str(i)
y_directs.append(str_i + '-' + stats[1]) y_directs.append(str_i + '-' + stats[1])
n_directs.append(str_i + '-' + stats[3]) n_directs.append(str_i + '-' + stats[3])
@ -2303,7 +2317,7 @@ class Booster(object):
df = DataFrame({'Tree': tree_ids, 'Node': node_ids, 'ID': ids, df = DataFrame({'Tree': tree_ids, 'Node': node_ids, 'ID': ids,
'Feature': fids, 'Split': splits, 'Yes': y_directs, 'Feature': fids, 'Split': splits, 'Yes': y_directs,
'No': n_directs, 'Missing': missings, 'Gain': gains, 'No': n_directs, 'Missing': missings, 'Gain': gains,
'Cover': covers}) 'Cover': covers, "Category": categories})
if callable(getattr(df, 'sort_values', None)): if callable(getattr(df, 'sort_values', None)):
# pylint: disable=no-member # pylint: disable=no-member
@ -2381,9 +2395,29 @@ class Booster(object):
nph = np.column_stack((nph[1][1:], nph[0])) nph = np.column_stack((nph[1][1:], nph[0]))
nph = nph[nph[:, 1] > 0] nph = nph[nph[:, 1] > 0]
if nph.size == 0:
ft = self.feature_types
fn = self.feature_names
if fn is None:
# Let xgboost generate the feature names.
fn = ["f{0}".format(i) for i in range(self.num_features())]
try:
index = fn.index(feature)
feature_t = ft[index]
except (ValueError, AttributeError, TypeError):
# None.index: attr err, None[0]: type err, fn.index(-1): value err
feature_t = None
if feature_t == "categorical":
raise ValueError(
"Split value historgam doesn't support categorical split."
)
if as_pandas and PANDAS_INSTALLED: if as_pandas and PANDAS_INSTALLED:
return DataFrame(nph, columns=['SplitValue', 'Count']) return DataFrame(nph, columns=['SplitValue', 'Count'])
if as_pandas and not PANDAS_INSTALLED: if as_pandas and not PANDAS_INSTALLED:
sys.stderr.write( warnings.warn(
"Returning histogram as ndarray (as_pandas == True, but pandas is not installed).") "Returning histogram as ndarray"
" (as_pandas == True, but pandas is not installed).",
UserWarning
)
return nph return nph

View File

@ -0,0 +1,25 @@
import sys
import pytest
import xgboost as xgb
sys.path.append("tests/python")
import testing as tm
def test_tree_to_df_categorical():
    """Check the ``Category`` column produced by ``trees_to_dataframe``.

    Trains a ``gpu_hist`` booster on a DMatrix built from
    ``tm.make_categorical`` with ``enable_categorical=True`` and asserts that
    every non-leaf row of the resulting dataframe holds exactly one entry in
    its ``Category`` column.
    """
    features, labels = tm.make_categorical(100, 10, 31, False)
    dtrain = xgb.DMatrix(features, labels, enable_categorical=True)
    booster = xgb.train({"tree_method": "gpu_hist"}, dtrain, num_boost_round=10)
    frame = booster.trees_to_dataframe()
    for _, row in frame.iterrows():
        # Leaf rows carry no split, so there is nothing to check on them.
        if row["Feature"] == "Leaf":
            continue
        assert len(row["Category"]) == 1
def test_split_value_histograms():
    """Check that a split value histogram request raises ``ValueError``.

    Fits a ``gpu_hist`` regressor with ``enable_categorical=True`` on data
    from ``tm.make_categorical`` and expects
    ``get_split_value_histogram("3", bins=5)`` to raise a ``ValueError``
    whose message matches ``"doesn't"``.
    """
    features, target = tm.make_categorical(1000, 10, 13, False)
    regressor = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
    regressor.fit(features, target)
    expected_failure = pytest.raises(ValueError, match="doesn't")
    with expected_failure:
        regressor.get_booster().get_split_value_histogram("3", bins=5)

View File

@ -32,15 +32,14 @@ def train_result(param, dmat, num_rounds):
class TestGPUUpdaters: class TestGPUUpdaters:
@given(parameter_strategy, strategies.integers(1, 20), @given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy)
tm.dataset_strategy)
@settings(deadline=None) @settings(deadline=None)
def test_gpu_hist(self, param, num_rounds, dataset): def test_gpu_hist(self, param, num_rounds, dataset):
param['tree_method'] = 'gpu_hist' param["tree_method"] = "gpu_hist"
param = dataset.set_params(param) param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds) result = train_result(param, dataset.get_dmat(), num_rounds)
note(result) note(result)
assert tm.non_increasing(result['train'][dataset.metric]) assert tm.non_increasing(result["train"][dataset.metric])
def run_categorical_basic(self, rows, cols, rounds, cats): def run_categorical_basic(self, rows, cols, rounds, cats):
onehot, label = tm.make_categorical(rows, cols, cats, True) onehot, label = tm.make_categorical(rows, cols, cats, True)
@ -49,25 +48,40 @@ class TestGPUUpdaters:
by_etl_results = {} by_etl_results = {}
by_builtin_results = {} by_builtin_results = {}
parameters = {'tree_method': 'gpu_hist', 'predictor': 'gpu_predictor'} parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
m = xgb.DMatrix(onehot, label, enable_categorical=True) m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(parameters, m, xgb.train(
parameters,
m,
num_boost_round=rounds, num_boost_round=rounds,
evals=[(m, 'Train')], evals_result=by_etl_results) evals=[(m, "Train")],
evals_result=by_etl_results,
)
m = xgb.DMatrix(cat, label, enable_categorical=True) m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(parameters, m, xgb.train(
parameters,
m,
num_boost_round=rounds, num_boost_round=rounds,
evals=[(m, 'Train')], evals_result=by_builtin_results) evals=[(m, "Train")],
evals_result=by_builtin_results,
)
# There are guidelines on how to specify tolerance by treating the output as
# random variables. But here the tree construction is extremely sensitive to
# floating point errors. A 1e-5 error in a histogram bin can lead to an entirely
# different tree. So even though the test is quite lenient, hypothesis can still
# pick up falsifying examples from time to time.
np.testing.assert_allclose( np.testing.assert_allclose(
np.array(by_etl_results['Train']['rmse']), np.array(by_etl_results["Train"]["rmse"]),
np.array(by_builtin_results['Train']['rmse']), np.array(by_builtin_results["Train"]["rmse"]),
rtol=1e-3) rtol=1e-3,
assert tm.non_increasing(by_builtin_results['Train']['rmse']) )
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
@given(strategies.integers(10, 400), strategies.integers(3, 8), @given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 5), strategies.integers(4, 7)) strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None) @settings(deadline=None)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_categorical(self, rows, cols, rounds, cats): def test_categorical(self, rows, cols, rounds, cats):