Handle missing values in one hot splits. (#7934)

Jiaming Yuan
2022-05-24 20:48:41 +08:00
committed by GitHub
parent 18a38f7ca0
commit 606be9e663
3 changed files with 105 additions and 14 deletions

View File

@@ -214,6 +214,9 @@ class TestTreeMethod:
         self.run_max_cat(tree_method)
 
     def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
+        USE_ONEHOT = np.iinfo(np.int32).max
+        USE_PART = 1
+
         onehot, label = tm.make_categorical(rows, cols, cats, True)
         cat, _ = tm.make_categorical(rows, cols, cats, False)
@@ -221,10 +224,9 @@ class TestTreeMethod:
         by_builtin_results = {}
 
         predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
+        parameters = {"tree_method": tree_method, "predictor": predictor}
         # Use one-hot exclusively
-        parameters = {
-            "tree_method": tree_method, "predictor": predictor, "max_cat_to_onehot": 9999
-        }
+        parameters["max_cat_to_onehot"] = USE_ONEHOT
 
         m = xgb.DMatrix(onehot, label, enable_categorical=False)
         xgb.train(
@@ -257,7 +259,8 @@ class TestTreeMethod:
         assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
 
         by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
-        parameters["max_cat_to_onehot"] = 1
+        # switch to partition-based splits
+        parameters["max_cat_to_onehot"] = USE_PART
         parameters["reg_lambda"] = 0
         m = xgb.DMatrix(cat, label, enable_categorical=True)
         xgb.train(
@@ -284,6 +287,27 @@ class TestTreeMethod:
         )
         assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
 
+        # test with missing values
+        cat, label = tm.make_categorical(
+            n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
+        )
+        Xy = xgb.DMatrix(cat, label, enable_categorical=True)
+        evals_result = {}
+        # Test with onehot splits
+        parameters["max_cat_to_onehot"] = USE_ONEHOT
+        booster = xgb.train(
+            parameters,
+            Xy,
+            num_boost_round=16,
+            evals=[(Xy, "Train")],
+            evals_result=evals_result
+        )
+        assert tm.non_increasing(evals_result["Train"]["rmse"])
+        y_predt = booster.predict(Xy)
+        rmse = tm.root_mean_square(label, y_predt)
+        np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
+
     @given(strategies.integers(10, 400), strategies.integers(3, 8),
            strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None, print_blob=True)
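
The two constants introduced at the top of the test make the split strategy explicit: per the XGBoost documentation, a categorical feature is one-hot split when its number of categories is below max_cat_to_onehot, so np.iinfo(np.int32).max forces one-hot splits for every feature and 1 forces partition-based splits. Below is a minimal standalone sketch of the same toggle on categorical data with injected missing values; the frame, label, tree_method="hist" and round count are illustrative assumptions, not taken from the commit.

import numpy as np
import pandas as pd
import xgboost as xgb

rng = np.random.RandomState(1994)
# Toy categorical frame with NaNs, standing in for tm.make_categorical(..., sparsity=0.5).
df = pd.DataFrame(
    {f"f{i}": pd.Categorical(rng.randint(0, 8, size=256)) for i in range(4)}
)
for col in df.columns:
    df.loc[rng.choice(256, size=128, replace=False), col] = np.nan
label = rng.randn(256)

Xy = xgb.DMatrix(df, label, enable_categorical=True)
for max_cat in (np.iinfo(np.int32).max, 1):  # one-hot vs. partition-based splits
    booster = xgb.train(
        {"tree_method": "hist", "max_cat_to_onehot": max_cat},
        Xy,
        num_boost_round=16,
        evals=[(Xy, "Train")],
    )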

View File

@@ -302,7 +302,7 @@ def get_mq2008(dpath):
 @memory.cache
 def make_categorical(
-    n_samples: int, n_features: int, n_categories: int, onehot: bool
+    n_samples: int, n_features: int, n_categories: int, onehot: bool, sparsity=0.0,
 ):
     import pandas as pd
@@ -325,6 +325,13 @@ def make_categorical(
     for col in df.columns:
         df[col] = df[col].cat.set_categories(categories)
 
+    if sparsity > 0.0:
+        for i in range(n_features):
+            index = rng.randint(low=0, high=n_samples-1, size=int(n_samples * sparsity))
+            df.iloc[index, i] = np.NaN
+            assert df.iloc[:, i].isnull().values.any()
+            assert n_categories == np.unique(df.dtypes[i].categories).size
+
     if onehot:
         return pd.get_dummies(df), label
     return df, label
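
The two asserts in the new sparsity branch encode the invariant the helper relies on: assigning NaN into a categorical column marks entries as missing without shrinking the column's category set. A small self-contained illustration of that pandas behaviour follows; the frame and its size are made up for the example.

import numpy as np
import pandas as pd

rng = np.random.RandomState(1994)
df = pd.DataFrame(
    {"f0": pd.Categorical(rng.randint(0, 8, size=256), categories=np.arange(8))}
)

# Blank out roughly half of the rows, mirroring sparsity=0.5.
index = rng.randint(low=0, high=255, size=128)
df.iloc[index, 0] = np.nan

assert df.iloc[:, 0].isnull().values.any()  # missing values are really present
assert df["f0"].cat.categories.size == 8    # the category set is unchanged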
@@ -538,6 +545,12 @@ def eval_error_metric_skl(y_true: np.ndarray, y_score: np.ndarray) -> float:
     return np.sum(r)
 
 
+def root_mean_square(y_true: np.ndarray, y_score: np.ndarray) -> float:
+    err = y_score - y_true
+    rmse = np.sqrt(np.dot(err, err) / y_score.size)
+    return rmse
+
+
 def softmax(x):
     e = np.exp(x)
     return e / np.sum(e)
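
The new root_mean_square helper computes the root-mean-square error through a dot product: sqrt(sum((y_score - y_true)^2) / n). A quick worked check with invented numbers, independent of xgboost:

import numpy as np

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_score = np.array([1.5, 1.5, 2.5, 4.5])

err = y_score - y_true                           # [0.5, -0.5, -0.5, 0.5]
rmse = np.sqrt(np.dot(err, err) / y_score.size)  # sqrt(1.0 / 4)
print(rmse)                                      # 0.5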