Handle categorical split in model histogram and dataframe. (#7065)
* Error on get_split_value_histogram when feature is categorical.
* Add a category column to output dataframe.
This commit is contained in:
parent
1cd20efe68
commit
a5d222fcdb
@ -2225,7 +2225,7 @@ class Booster(object):
|
||||
results[feat] = float(score)
|
||||
return results
|
||||
|
||||
def trees_to_dataframe(self, fmap=''):
|
||||
def trees_to_dataframe(self, fmap=''): # pylint: disable=too-many-statements
|
||||
"""Parse a boosted tree model text dump into a pandas DataFrame structure.
|
||||
|
||||
This feature is only defined when the decision tree model is chosen as base
|
||||
@ -2251,6 +2251,7 @@ class Booster(object):
|
||||
node_ids = []
|
||||
fids = []
|
||||
splits = []
|
||||
categories = []
|
||||
y_directs = []
|
||||
n_directs = []
|
||||
missings = []
|
||||
@ -2275,6 +2276,7 @@ class Booster(object):
|
||||
node_ids.append(int(re.findall(r'\b\d+\b', parse[0])[0]))
|
||||
fids.append('Leaf')
|
||||
splits.append(float('NAN'))
|
||||
categories.append(float('NAN'))
|
||||
y_directs.append(float('NAN'))
|
||||
n_directs.append(float('NAN'))
|
||||
missings.append(float('NAN'))
|
||||
@ -2284,14 +2286,26 @@ class Booster(object):
|
||||
else:
|
||||
# parse string
|
||||
fid = arr[1].split(']')
|
||||
parse = fid[0].split('<')
|
||||
if fid[0].find("<") != -1:
|
||||
# numerical
|
||||
parse = fid[0].split('<')
|
||||
splits.append(float(parse[1]))
|
||||
categories.append(None)
|
||||
elif fid[0].find(":{") != -1:
|
||||
# categorical
|
||||
parse = fid[0].split(":")
|
||||
cats = parse[1][1:-1] # strip the {}
|
||||
cats = cats.split(",")
|
||||
splits.append(float("NAN"))
|
||||
categories.append(cats if cats else None)
|
||||
else:
|
||||
raise ValueError("Failed to parse model text dump.")
|
||||
stats = re.split('=|,', fid[1])
|
||||
|
||||
# append to lists
|
||||
tree_ids.append(i)
|
||||
node_ids.append(int(re.findall(r'\b\d+\b', arr[0])[0]))
|
||||
fids.append(parse[0])
|
||||
splits.append(float(parse[1]))
|
||||
str_i = str(i)
|
||||
y_directs.append(str_i + '-' + stats[1])
|
||||
n_directs.append(str_i + '-' + stats[3])
|
||||
@ -2303,7 +2317,7 @@ class Booster(object):
|
||||
df = DataFrame({'Tree': tree_ids, 'Node': node_ids, 'ID': ids,
|
||||
'Feature': fids, 'Split': splits, 'Yes': y_directs,
|
||||
'No': n_directs, 'Missing': missings, 'Gain': gains,
|
||||
'Cover': covers})
|
||||
'Cover': covers, "Category": categories})
|
||||
|
||||
if callable(getattr(df, 'sort_values', None)):
|
||||
# pylint: disable=no-member
|
||||
@ -2381,9 +2395,29 @@ class Booster(object):
|
||||
nph = np.column_stack((nph[1][1:], nph[0]))
|
||||
nph = nph[nph[:, 1] > 0]
|
||||
|
||||
if nph.size == 0:
|
||||
ft = self.feature_types
|
||||
fn = self.feature_names
|
||||
if fn is None:
|
||||
# Let xgboost generate the feature names.
|
||||
fn = ["f{0}".format(i) for i in range(self.num_features())]
|
||||
try:
|
||||
index = fn.index(feature)
|
||||
feature_t = ft[index]
|
||||
except (ValueError, AttributeError, TypeError):
|
||||
# None.index: attr err, None[0]: type err, fn.index(-1): value err
|
||||
feature_t = None
|
||||
if feature_t == "categorical":
|
||||
raise ValueError(
|
||||
"Split value historgam doesn't support categorical split."
|
||||
)
|
||||
|
||||
if as_pandas and PANDAS_INSTALLED:
|
||||
return DataFrame(nph, columns=['SplitValue', 'Count'])
|
||||
if as_pandas and not PANDAS_INSTALLED:
|
||||
sys.stderr.write(
|
||||
"Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
|
||||
warnings.warn(
|
||||
"Returning histogram as ndarray"
|
||||
" (as_pandas == True, but pandas is not installed).",
|
||||
UserWarning
|
||||
)
|
||||
return nph
|
||||
|
||||
25
tests/python-gpu/test_gpu_parse_tree.py
Normal file
25
tests/python-gpu/test_gpu_parse_tree.py
Normal file
@ -0,0 +1,25 @@
|
||||
import sys
|
||||
import pytest
|
||||
import xgboost as xgb
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
|
||||
|
||||
def test_tree_to_df_categorical():
    """Check the ``Category`` column produced by ``trees_to_dataframe``.

    Trains a ``gpu_hist`` booster on data from ``tm.make_categorical`` with
    ``enable_categorical=True`` and asserts that every non-leaf row of the
    resulting dataframe carries a ``Category`` value of length one.
    """
    features, labels = tm.make_categorical(100, 10, 31, False)
    dtrain = xgb.DMatrix(features, labels, enable_categorical=True)
    model = xgb.train({"tree_method": "gpu_hist"}, dtrain, num_boost_round=10)
    tree_table = model.trees_to_dataframe()

    for _, row in tree_table.iterrows():
        # Leaf rows have no split and therefore nothing to check.
        if row["Feature"] == "Leaf":
            continue
        assert len(row["Category"]) == 1
|
||||
|
||||
|
||||
def test_split_value_histograms():
    """Check that the split value histogram is rejected for this model.

    Fits a ``gpu_hist`` regressor with ``enable_categorical=True`` on data
    from ``tm.make_categorical`` and expects ``get_split_value_histogram``
    on feature ``"3"`` to raise ``ValueError`` with a message containing
    ``"doesn't"``.
    """
    features, labels = tm.make_categorical(1000, 10, 13, False)
    regressor = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
    regressor.fit(features, labels)

    expected_failure = pytest.raises(ValueError, match="doesn't")
    with expected_failure:
        regressor.get_booster().get_split_value_histogram("3", bins=5)
|
||||
@ -32,15 +32,14 @@ def train_result(param, dmat, num_rounds):
|
||||
|
||||
|
||||
class TestGPUUpdaters:
|
||||
@given(parameter_strategy, strategies.integers(1, 20),
|
||||
tm.dataset_strategy)
|
||||
@given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy)
|
||||
@settings(deadline=None)
|
||||
def test_gpu_hist(self, param, num_rounds, dataset):
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
param["tree_method"] = "gpu_hist"
|
||||
param = dataset.set_params(param)
|
||||
result = train_result(param, dataset.get_dmat(), num_rounds)
|
||||
note(result)
|
||||
assert tm.non_increasing(result['train'][dataset.metric])
|
||||
assert tm.non_increasing(result["train"][dataset.metric])
|
||||
|
||||
def run_categorical_basic(self, rows, cols, rounds, cats):
|
||||
onehot, label = tm.make_categorical(rows, cols, cats, True)
|
||||
@ -49,25 +48,40 @@ class TestGPUUpdaters:
|
||||
by_etl_results = {}
|
||||
by_builtin_results = {}
|
||||
|
||||
parameters = {'tree_method': 'gpu_hist', 'predictor': 'gpu_predictor'}
|
||||
parameters = {"tree_method": "gpu_hist", "predictor": "gpu_predictor"}
|
||||
|
||||
m = xgb.DMatrix(onehot, label, enable_categorical=True)
|
||||
xgb.train(parameters, m,
|
||||
num_boost_round=rounds,
|
||||
evals=[(m, 'Train')], evals_result=by_etl_results)
|
||||
m = xgb.DMatrix(onehot, label, enable_categorical=False)
|
||||
xgb.train(
|
||||
parameters,
|
||||
m,
|
||||
num_boost_round=rounds,
|
||||
evals=[(m, "Train")],
|
||||
evals_result=by_etl_results,
|
||||
)
|
||||
|
||||
m = xgb.DMatrix(cat, label, enable_categorical=True)
|
||||
xgb.train(parameters, m,
|
||||
num_boost_round=rounds,
|
||||
evals=[(m, 'Train')], evals_result=by_builtin_results)
|
||||
xgb.train(
|
||||
parameters,
|
||||
m,
|
||||
num_boost_round=rounds,
|
||||
evals=[(m, "Train")],
|
||||
evals_result=by_builtin_results,
|
||||
)
|
||||
|
||||
# There are guidelines on how to specify tolerance based on considering output as
|
||||
# random variables. But in here the tree construction is extremely sensitive to
|
||||
# floating point errors. An 1e-5 error in a histogram bin can lead to an entirely
|
||||
# different tree. So even though the test is quite lenient, hypothesis can still
|
||||
# pick up falsifying examples from time to time.
|
||||
np.testing.assert_allclose(
|
||||
np.array(by_etl_results['Train']['rmse']),
|
||||
np.array(by_builtin_results['Train']['rmse']),
|
||||
rtol=1e-3)
|
||||
assert tm.non_increasing(by_builtin_results['Train']['rmse'])
|
||||
np.array(by_etl_results["Train"]["rmse"]),
|
||||
np.array(by_builtin_results["Train"]["rmse"]),
|
||||
rtol=1e-3,
|
||||
)
|
||||
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
|
||||
|
||||
@given(strategies.integers(10, 400), strategies.integers(3, 8),
|
||||
strategies.integers(1, 5), strategies.integers(4, 7))
|
||||
strategies.integers(1, 2), strategies.integers(4, 7))
|
||||
@settings(deadline=None)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(self, rows, cols, rounds, cats):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user