Implement max_cat_threshold for CPU. (#7957)

Jiaming Yuan 2022-06-04 11:02:46 +08:00 committed by GitHub
parent 78694405a6
commit b90c6d25e8
8 changed files with 177 additions and 20 deletions


@@ -235,14 +235,19 @@ Parameters for Tree Booster
   list is a group of indices of features that are allowed to interact with each other.
   See :doc:`/tutorials/feature_interaction_constraint` for more information.
 
-Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
-===========================================================================
+.. _cat-param:
+
+Parameters for Categorical Feature
+==================================
+
+These parameters are only used for training with categorical data. See
+:doc:`/tutorials/categorical` for more information.
 
 * ``max_cat_to_onehot``
 
   .. versionadded:: 1.6
 
-  .. note:: The support for this parameter is experimental.
+  .. note:: This parameter is experimental. ``exact`` tree method is not supported yet.
 
 - A threshold for deciding whether XGBoost should use one-hot encoding based split for
   categorical data. When number of categories is lesser than the threshold then one-hot
@@ -250,6 +255,16 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
   Only relevant for regression and binary classification. Also, ``exact`` tree method is
   not supported
 
+* ``max_cat_threshold``
+
+  .. versionadded:: 2.0
+
+  .. note:: This parameter is experimental. ``exact`` and ``gpu_hist`` tree methods are
+     not supported yet.
+
+- Maximum number of categories considered for each split. Used only by partition-based
+  splits for preventing over-fitting.
+
 Additional parameters for Dart Booster (``booster=dart``)
 =========================================================
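A minimal sketch of how the two thresholds documented above are used from the Python API; the toy data, threshold values, and round count are illustrative, not part of the commit:

```python
import numpy as np
import pandas as pd
import xgboost as xgb

rng = np.random.default_rng(0)
# Toy frame with one categorical and one numerical feature.
X = pd.DataFrame(
    {
        "cat": pd.Categorical(rng.integers(0, 32, size=256)),
        "num": rng.normal(size=256),
    }
)
y = rng.normal(size=256)

dtrain = xgb.DMatrix(X, y, enable_categorical=True)
booster = xgb.train(
    {
        "tree_method": "hist",
        "max_cat_to_onehot": 8,   # features with < 8 categories: one-hot splits
        "max_cat_threshold": 16,  # partition splits consider at most 16 categories
    },
    dtrain,
    num_boost_round=4,
)
```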


@@ -85,7 +85,7 @@ group the categories that output similar leaf values. During split finding, we f
 the gradient histogram to prepare the contiguous partitions then enumerate the splits
 according to these sorted values. One of the related parameters for XGBoost is
 ``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be
-used for each feature, see :doc:`/parameter` for details.
+used for each feature, see :ref:`cat-param` for details.
 
 **********************
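A sketch of the sort-then-enumerate procedure the tutorial text describes, with per-category gradient sums and a hand-rolled second-order gain standing in for XGBoost's histogram internals; the names and gain details are illustrative only:

```python
import numpy as np

def best_partition(grad: np.ndarray, hess: np.ndarray, reg_lambda: float = 1.0):
    """Sketch of partition-based categorical split finding.

    ``grad``/``hess`` hold per-category gradient and hessian sums.
    Categories are sorted by their gradient ratio (a proxy for the leaf
    value each would get), then every prefix of the sorted order is
    tried as the left partition.
    """
    def gain(g: float, h: float) -> float:
        return g * g / (h + reg_lambda)

    order = np.argsort(grad / (hess + reg_lambda))
    g_tot, h_tot = grad.sum(), hess.sum()
    parent_gain = gain(g_tot, h_tot)

    best_chg, best_prefix = -np.inf, order[:1]
    g_left = h_left = 0.0
    for i in range(len(order) - 1):  # keep at least one category per side
        c = order[i]
        g_left += grad[c]
        h_left += hess[c]
        chg = gain(g_left, h_left) + gain(g_tot - g_left, h_tot - h_left) - parent_gain
        if chg > best_chg:
            best_chg, best_prefix = chg, order[: i + 1]
    return best_chg, best_prefix
```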


@@ -54,7 +54,7 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
  */
 template <bool validate = true>
 inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
-  CLBitField32 const s_cats(cats);
+  KCatBitField const s_cats(cats);
   // FIXME: Size() is not accurate since it represents the size of bit set instead of
   // actual number of categories.
   if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
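For context on the bit field this hunk switches over, a Python sketch of the decision semantics, assuming 32-bit storage words with one bit per category; invalid or out-of-range categories fall back to the default direction, as in the guarded check above:

```python
import math

def decision(cat_bits: list, cat: float, dft_left: bool) -> bool:
    """Sketch of a categorical split decision over a packed bit set.

    ``cat_bits`` holds the split's categories packed into 32-bit words;
    a set bit routes the category to the matching branch.  Word size is
    an assumption of this sketch.
    """
    n_bits = 32 * len(cat_bits)
    # Mirror the InvalidCat/Size guard: bad categories take the default.
    if not math.isfinite(cat) or cat < 0 or cat != int(cat) or cat >= n_bits:
        return dft_left
    c = int(cat)
    return bool((cat_bits[c // 32] >> (c % 32)) & 1)
```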


@@ -144,7 +144,8 @@ class HistEvaluator {
     auto const &cut_ptr = cut.Ptrs();
     auto const &parent = snode_[nidx];
-    bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
+    bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
+    auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
 
     // statistics on both sides of split
     GradStats left_sum;
@@ -152,7 +153,7 @@ class HistEvaluator {
     // best split so far
     SplitEntry best;
-    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
+    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
     bst_bin_t ibegin, iend;
     bst_bin_t f_begin = cut_ptr[fidx];
     if (d_step > 0) {
@@ -160,7 +161,7 @@ class HistEvaluator {
       iend = ibegin + n_bins - 1;
     } else {
       ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
-      iend = f_begin;
+      iend = ibegin - n_bins + 1;
     }
     bst_bin_t best_thresh{-1};
@@ -177,7 +178,7 @@ class HistEvaluator {
       auto loss_chg =
           evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
          parent.root_gain;
-      // We don't have a numeric split point, nan hare is a dummy split.
+      // We don't have a numeric split point, nan here is a dummy split.
       if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
                       left_sum, right_sum)) {
         best_thresh = i;
@@ -186,10 +187,11 @@ class HistEvaluator {
     }
 
     if (best_thresh != -1) {
-      auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
+      auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
       best.cat_bits = decltype(best.cat_bits)(n, 0);
       common::CatBitField cat_bits{best.cat_bits};
-      bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
+      bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
+      CHECK_GT(partition, 0);
       std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
                     [&](size_t c) { cat_bits.Set(c); });
     }
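The last hunk builds the winning split's category bit set from a prefix of the sorted bin indices. A sketch of that step, assuming 32-bit words for the `CatBitField` storage:

```python
def build_cat_bits(sorted_idx: list, partition: int, n_bins_feature: int) -> list:
    """Sketch of the cat_bits construction in the hunk above.

    Storage is sized for ``n_bins_feature + 1`` bits, and the first
    ``partition`` categories in sorted order are set, mirroring
    ``cat_bits.Set(c)``.  The 32-bit word size is an assumption here.
    """
    n_words = (n_bins_feature + 1 + 31) // 32  # ComputeStorageSize analogue
    bits = [0] * n_words
    for c in sorted_idx[:partition]:
        bits[c // 32] |= 1 << (c % 32)
    return bits
```

Note that `max_cat_threshold` caps only how many bins the scan visits (`n_bins`); the bit set is still sized from `n_bins_feature`, since any category index of the feature can appear in the chosen prefix.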


@@ -40,6 +40,8 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
   uint32_t max_cat_to_onehot{4};
 
+  bst_bin_t max_cat_threshold{64};
+
   //----- the rest parameters are less important ----
   // minimum amount of hessian(weight) allowed in a child
   float min_child_weight;
@@ -113,6 +115,12 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
         .set_default(4)
         .set_lower_bound(1)
         .describe("Maximum number of categories to use one-hot encoding based split.");
+    DMLC_DECLARE_FIELD(max_cat_threshold)
+        .set_default(64)
+        .set_lower_bound(1)
+        .describe(
+            "Maximum number of categories considered for split. Used only by partition-based "
+            "splits.");
     DMLC_DECLARE_FIELD(min_child_weight)
         .set_lower_bound(0.0f)
         .set_default(1.0f)
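Taken together with `max_cat_to_onehot` declared just above (defaults 4 and 64), the per-feature behavior can be summarized in a small sketch; the function and its name are ours, following the documented semantics:

```python
def split_kind(n_categories: int, max_cat_to_onehot: int = 4,
               max_cat_threshold: int = 64) -> str:
    """Sketch of how the two categorical thresholds interact.

    Defaults mirror the declarations above; this is an illustration,
    not XGBoost code.
    """
    if n_categories < max_cat_to_onehot:
        # Few categories: one-hot based splits (one category vs. the rest).
        return "one-hot split"
    # Otherwise partition-based splits, considering at most
    # max_cat_threshold sorted categories per split.
    considered = min(n_categories, max_cat_threshold)
    return f"partition split over {considered} sorted categories"
```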


@@ -74,8 +74,8 @@ class TestGPUUpdaters:
                strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None, print_blob=True)
     @pytest.mark.skipif(**tm.no_pandas())
-    def test_categorical(self, rows, cols, rounds, cats):
-        self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
+    def test_categorical_ohe(self, rows, cols, rounds, cats):
+        self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
 
     @given(
         strategies.integers(10, 400),
@@ -96,7 +96,7 @@ class TestGPUUpdaters:
         cols = 10
         cats = 32
         rounds = 4
-        self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
+        self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
 
     @pytest.mark.skipif(**tm.no_cupy())
     def test_invalid_category(self):


@@ -31,6 +31,14 @@ hist_parameter_strategy = strategies.fixed_dictionaries({
     x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))
 
+cat_parameter_strategy = strategies.fixed_dictionaries(
+    {
+        "max_cat_to_onehot": strategies.integers(1, 128),
+        "max_cat_threshold": strategies.integers(1, 128),
+    }
+)
+
 
 def train_result(param, dmat, num_rounds):
     result = {}
     xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
@@ -253,7 +261,7 @@ class TestTreeMethod:
         # Test with partition-based split
         run(self.USE_PART)
 
-    def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
+    def run_categorical_ohe(self, rows, cols, rounds, cats, tree_method):
         onehot, label = tm.make_categorical(rows, cols, cats, True)
         cat, _ = tm.make_categorical(rows, cols, cats, False)
@@ -328,9 +336,55 @@ class TestTreeMethod:
                strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None, print_blob=True)
     @pytest.mark.skipif(**tm.no_pandas())
-    def test_categorical(self, rows, cols, rounds, cats):
-        self.run_categorical_basic(rows, cols, rounds, cats, "approx")
-        self.run_categorical_basic(rows, cols, rounds, cats, "hist")
+    def test_categorical_ohe(self, rows, cols, rounds, cats):
+        self.run_categorical_ohe(rows, cols, rounds, cats, "approx")
+        self.run_categorical_ohe(rows, cols, rounds, cats, "hist")
+
+    @given(
+        tm.categorical_dataset_strategy,
+        exact_parameter_strategy,
+        hist_parameter_strategy,
+        cat_parameter_strategy,
+        strategies.integers(4, 32),
+        strategies.sampled_from(["hist", "approx"]),
+    )
+    @settings(deadline=None, print_blob=True)
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_categorical(
+        self,
+        dataset: tm.TestDataset,
+        exact_parameters: Dict[str, Any],
+        hist_parameters: Dict[str, Any],
+        cat_parameters: Dict[str, Any],
+        n_rounds: int,
+        tree_method: str,
+    ) -> None:
+        cat_parameters.update(exact_parameters)
+        cat_parameters.update(hist_parameters)
+        cat_parameters["tree_method"] = tree_method
+
+        results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
+        tm.non_increasing(results["train"]["rmse"])
+
+    @given(
+        hist_parameter_strategy,
+        cat_parameter_strategy,
+        strategies.sampled_from(["hist", "approx"]),
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_categorical_ames_housing(
+        self,
+        hist_parameters: Dict[str, Any],
+        cat_parameters: Dict[str, Any],
+        tree_method: str,
+    ) -> None:
+        cat_parameters.update(hist_parameters)
+        dataset = tm.TestDataset(
+            "ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
+        )
+
+        cat_parameters["tree_method"] = tree_method
+        results = train_result(cat_parameters, dataset.get_dmat(), 16)
+        tm.non_increasing(results["train"]["rmse"])
+
     @given(
         strategies.integers(10, 400),


@@ -214,7 +214,9 @@ class TestDataset:
         return params_in
 
     def get_dmat(self):
-        return xgb.DMatrix(self.X, self.y, self.w, base_margin=self.margin)
+        return xgb.DMatrix(
+            self.X, self.y, self.w, base_margin=self.margin, enable_categorical=True
+        )
 
     def get_device_dmat(self):
         w = None if self.w is None else cp.array(self.w)
@@ -277,6 +279,48 @@ def get_sparse():
     return X, y
 
 
+@memory.cache
+def get_ames_housing():
+    """
+    Number of samples: 1460
+    Number of features: 20
+    Number of categorical features: 10
+    Number of numerical features: 10
+    """
+    from sklearn.datasets import fetch_openml
+    X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True)
+
+    categorical_columns_subset: list[str] = [
+        "BldgType",       # 5 cats, no nan
+        "GarageFinish",   # 3 cats, nan
+        "LotConfig",      # 5 cats, no nan
+        "Functional",     # 7 cats, no nan
+        "MasVnrType",     # 4 cats, nan
+        "HouseStyle",     # 8 cats, no nan
+        "FireplaceQu",    # 5 cats, nan
+        "ExterCond",      # 5 cats, no nan
+        "ExterQual",      # 4 cats, no nan
+        "PoolQC",         # 3 cats, nan
+    ]
+
+    numerical_columns_subset: list[str] = [
+        "3SsnPorch",
+        "Fireplaces",
+        "BsmtHalfBath",
+        "HalfBath",
+        "GarageCars",
+        "TotRmsAbvGrd",
+        "BsmtFinSF1",
+        "BsmtFinSF2",
+        "GrLivArea",
+        "ScreenPorch",
+    ]
+
+    X = X[categorical_columns_subset + numerical_columns_subset]
+    X[categorical_columns_subset] = X[categorical_columns_subset].astype("category")
+    return X, y
+
+
 @memory.cache
 def get_mq2008(dpath):
     from sklearn.datasets import load_svmlight_files
@@ -329,7 +373,6 @@ def make_categorical(
     for i in range(n_features):
         index = rng.randint(low=0, high=n_samples-1, size=int(n_samples * sparsity))
         df.iloc[index, i] = np.NaN
-        assert df.iloc[:, i].isnull().values.any()
         assert n_categories == np.unique(df.dtypes[i].categories).size
 
     if onehot:
@@ -337,6 +380,41 @@ def make_categorical(
     return df, label
 
 
+def _cat_sampled_from():
+    @strategies.composite
+    def _make_cat(draw):
+        n_samples = draw(strategies.integers(2, 512))
+        n_features = draw(strategies.integers(1, 4))
+        n_cats = draw(strategies.integers(1, 128))
+        sparsity = draw(
+            strategies.floats(
+                min_value=0,
+                max_value=1,
+                allow_nan=False,
+                allow_infinity=False,
+                allow_subnormal=False,
+            )
+        )
+        return n_samples, n_features, n_cats, sparsity
+
+    def _build(args):
+        n_samples = args[0]
+        n_features = args[1]
+        n_cats = args[2]
+        sparsity = args[3]
+        return TestDataset(
+            f"{n_samples}x{n_features}-{n_cats}-{sparsity}",
+            lambda: make_categorical(n_samples, n_features, n_cats, False, sparsity),
+            "reg:squarederror",
+            "rmse",
+        )
+
+    return _make_cat().map(_build)
+
+
+categorical_dataset_strategy = _cat_sampled_from()
+
+
 @memory.cache
 def make_sparse_regression(
     n_samples: int, n_features: int, sparsity: float, as_dense: bool