Handle missing categorical value in CPU evaluator. (#7948)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from random import choice
|
||||
from string import ascii_lowercase
|
||||
from typing import Dict, Any
|
||||
import testing as tm
|
||||
import pytest
|
||||
import xgboost as xgb
|
||||
@@ -38,6 +39,9 @@ def train_result(param, dmat, num_rounds):
|
||||
|
||||
|
||||
class TestTreeMethod:
|
||||
USE_ONEHOT = np.iinfo(np.int32).max
|
||||
USE_PART = 1
|
||||
|
||||
@given(exact_parameter_strategy, strategies.integers(1, 20),
|
||||
tm.dataset_strategy)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
@@ -213,10 +217,43 @@ class TestTreeMethod:
|
||||
def test_max_cat(self, tree_method) -> None:
    """Delegate to ``run_max_cat`` for the given ``tree_method``."""
    self.run_max_cat(tree_method)
|
||||
|
||||
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
|
||||
USE_ONEHOT = np.iinfo(np.int32).max
|
||||
USE_PART = 1
|
||||
def run_categorical_missing(
    self, rows: int, cols: int, cats: int, tree_method: str
) -> None:
    """Train on categorical data that contains missing values.

    A categorical dataset is generated with ``sparsity=0.5`` and trained
    for 16 boosting rounds.  For every split type exercised, the training
    RMSE must be non-increasing and the RMSE recomputed from
    ``booster.predict`` must agree with the last value reported by the
    evaluation log.

    Parameters
    ----------
    rows :
        Number of samples to generate.
    cols :
        Number of categorical features to generate.
    cats :
        Number of categories per feature.
    tree_method :
        Value forwarded as the ``tree_method`` training parameter.
    """
    parameters: Dict[str, Any] = {"tree_method": tree_method}
    # Honour the caller-supplied sizes so that generated test inputs
    # actually vary the data instead of always using a fixed shape.
    cat, label = tm.make_categorical(
        n_samples=rows,
        n_features=cols,
        n_categories=cats,
        onehot=False,
        sparsity=0.5,
    )
    Xy = xgb.DMatrix(cat, label, enable_categorical=True)

    def run(max_cat_to_onehot: int) -> None:
        # The split type is selected by the caller through this threshold.
        parameters["max_cat_to_onehot"] = max_cat_to_onehot

        evals_result: Dict[str, Dict] = {}
        booster = xgb.train(
            parameters,
            Xy,
            num_boost_round=16,
            evals=[(Xy, "Train")],
            evals_result=evals_result,
        )
        assert tm.non_increasing(evals_result["Train"]["rmse"])
        y_predt = booster.predict(Xy)

        rmse = tm.root_mean_square(label, y_predt)
        # NOTE(review): predictions are presumably single precision, so the
        # default rtol of 1e-7 is too strict once the data size varies.
        np.testing.assert_allclose(
            rmse, evals_result["Train"]["rmse"][-1], rtol=2e-5
        )

    # Test with OHE split
    run(self.USE_ONEHOT)

    if tree_method == "gpu_hist":  # fixme: Test with GPU.
        return

    # Test with partition-based split
    run(self.USE_PART)
|
||||
|
||||
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
|
||||
onehot, label = tm.make_categorical(rows, cols, cats, True)
|
||||
cat, _ = tm.make_categorical(rows, cols, cats, False)
|
||||
|
||||
@@ -226,7 +263,7 @@ class TestTreeMethod:
|
||||
predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
|
||||
parameters = {"tree_method": tree_method, "predictor": predictor}
|
||||
# Use one-hot exclusively
|
||||
parameters["max_cat_to_onehot"] = USE_ONEHOT
|
||||
parameters["max_cat_to_onehot"] = self.USE_ONEHOT
|
||||
|
||||
m = xgb.DMatrix(onehot, label, enable_categorical=False)
|
||||
xgb.train(
|
||||
@@ -260,7 +297,7 @@ class TestTreeMethod:
|
||||
|
||||
by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
|
||||
# switch to partition-based splits
|
||||
parameters["max_cat_to_onehot"] = USE_PART
|
||||
parameters["max_cat_to_onehot"] = self.USE_PART
|
||||
parameters["reg_lambda"] = 0
|
||||
m = xgb.DMatrix(cat, label, enable_categorical=True)
|
||||
xgb.train(
|
||||
@@ -287,27 +324,6 @@ class TestTreeMethod:
|
||||
)
|
||||
assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
|
||||
|
||||
# test with missing values
|
||||
cat, label = tm.make_categorical(
|
||||
n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
|
||||
)
|
||||
Xy = xgb.DMatrix(cat, label, enable_categorical=True)
|
||||
evals_result = {}
|
||||
# Test with onehot splits
|
||||
parameters["max_cat_to_onehot"] = USE_ONEHOT
|
||||
booster = xgb.train(
|
||||
parameters,
|
||||
Xy,
|
||||
num_boost_round=16,
|
||||
evals=[(Xy, "Train")],
|
||||
evals_result=evals_result
|
||||
)
|
||||
assert tm.non_increasing(evals_result["Train"]["rmse"])
|
||||
y_predt = booster.predict(Xy)
|
||||
|
||||
rmse = tm.root_mean_square(label, y_predt)
|
||||
np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
|
||||
|
||||
@given(strategies.integers(10, 400), strategies.integers(3, 8),
|
||||
strategies.integers(1, 2), strategies.integers(4, 7))
|
||||
@settings(deadline=None, print_blob=True)
|
||||
@@ -315,3 +331,14 @@ class TestTreeMethod:
|
||||
def test_categorical(self, rows, cols, rounds, cats):
    """Run the basic categorical checks with ``approx`` and then ``hist``."""
    for tree_method in ("approx", "hist"):
        self.run_categorical_basic(rows, cols, rounds, cats, tree_method)
|
||||
|
||||
@given(
    strategies.integers(10, 400),
    strategies.integers(3, 8),
    strategies.integers(4, 7),
)
@settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical_missing(self, rows, cols, cats):
    """Run the missing-value categorical checks with ``approx`` and ``hist``."""
    for tree_method in ("approx", "hist"):
        self.run_categorical_missing(rows, cols, cats, tree_method)
|
||||
|
||||
Reference in New Issue
Block a user