Implement max_cat_threshold for CPU. (#7957)
This commit is contained in:
parent
78694405a6
commit
b90c6d25e8
@ -235,14 +235,19 @@ Parameters for Tree Booster
|
||||
list is a group of indices of features that are allowed to interact with each other.
|
||||
See :doc:`/tutorials/feature_interaction_constraint` for more information.
|
||||
|
||||
Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
|
||||
===========================================================================
|
||||
.. _cat-param:
|
||||
|
||||
Parameters for Categorical Feature
|
||||
==================================
|
||||
|
||||
These parameters are only used for training with categorical data. See
|
||||
:doc:`/tutorials/categorical` for more information.
|
||||
|
||||
* ``max_cat_to_onehot``
|
||||
|
||||
.. versionadded:: 1.6
|
||||
|
||||
.. note:: The support for this parameter is experimental.
|
||||
.. note:: This parameter is experimental. ``exact`` tree method is not supported yet.
|
||||
|
||||
- A threshold for deciding whether XGBoost should use one-hot encoding based split for
|
||||
categorical data. When number of categories is lesser than the threshold then one-hot
|
||||
@ -250,6 +255,16 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
|
||||
Only relevant for regression and binary classification. Also, ``exact`` tree method is
|
||||
not supported
|
||||
|
||||
* ``max_cat_threshold``
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. note:: This parameter is experimental. ``exact`` and ``gpu_hist`` tree methods are
|
||||
not supported yet.
|
||||
|
||||
- Maximum number of categories considered for each split. Used only by partition-based
|
||||
splits for preventing over-fitting.
|
||||
|
||||
Additional parameters for Dart Booster (``booster=dart``)
|
||||
=========================================================
|
||||
|
||||
|
||||
@ -85,7 +85,7 @@ group the categories that output similar leaf values. During split finding, we f
|
||||
the gradient histogram to prepare the contiguous partitions then enumerate the splits
|
||||
according to these sorted values. One of the related parameters for XGBoost is
|
||||
``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be
|
||||
used for each feature, see :doc:`/parameter` for details.
|
||||
used for each feature, see :ref:`cat-param` for details.
|
||||
|
||||
|
||||
**********************
|
||||
|
||||
@ -54,7 +54,7 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
|
||||
*/
|
||||
template <bool validate = true>
|
||||
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
|
||||
CLBitField32 const s_cats(cats);
|
||||
KCatBitField const s_cats(cats);
|
||||
// FIXME: Size() is not accurate since it represents the size of bit set instead of
|
||||
// actual number of categories.
|
||||
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
|
||||
|
||||
@ -144,7 +144,8 @@ class HistEvaluator {
|
||||
|
||||
auto const &cut_ptr = cut.Ptrs();
|
||||
auto const &parent = snode_[nidx];
|
||||
bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
|
||||
bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
|
||||
auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
|
||||
|
||||
// statistics on both sides of split
|
||||
GradStats left_sum;
|
||||
@ -152,7 +153,7 @@ class HistEvaluator {
|
||||
// best split so far
|
||||
SplitEntry best;
|
||||
|
||||
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
|
||||
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
|
||||
bst_bin_t ibegin, iend;
|
||||
bst_bin_t f_begin = cut_ptr[fidx];
|
||||
if (d_step > 0) {
|
||||
@ -160,7 +161,7 @@ class HistEvaluator {
|
||||
iend = ibegin + n_bins - 1;
|
||||
} else {
|
||||
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
|
||||
iend = f_begin;
|
||||
iend = ibegin - n_bins + 1;
|
||||
}
|
||||
|
||||
bst_bin_t best_thresh{-1};
|
||||
@ -177,7 +178,7 @@ class HistEvaluator {
|
||||
auto loss_chg =
|
||||
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
|
||||
parent.root_gain;
|
||||
// We don't have a numeric split point, nan hare is a dummy split.
|
||||
// We don't have a numeric split point, nan here is a dummy split.
|
||||
if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
|
||||
left_sum, right_sum)) {
|
||||
best_thresh = i;
|
||||
@ -186,10 +187,11 @@ class HistEvaluator {
|
||||
}
|
||||
|
||||
if (best_thresh != -1) {
|
||||
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
|
||||
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
|
||||
best.cat_bits = decltype(best.cat_bits)(n, 0);
|
||||
common::CatBitField cat_bits{best.cat_bits};
|
||||
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
|
||||
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
|
||||
CHECK_GT(partition, 0);
|
||||
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
|
||||
[&](size_t c) { cat_bits.Set(c); });
|
||||
}
|
||||
|
||||
@ -40,6 +40,8 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
|
||||
|
||||
uint32_t max_cat_to_onehot{4};
|
||||
|
||||
bst_bin_t max_cat_threshold{64};
|
||||
|
||||
//----- the rest parameters are less important ----
|
||||
// minimum amount of hessian(weight) allowed in a child
|
||||
float min_child_weight;
|
||||
@ -113,6 +115,12 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
|
||||
.set_default(4)
|
||||
.set_lower_bound(1)
|
||||
.describe("Maximum number of categories to use one-hot encoding based split.");
|
||||
DMLC_DECLARE_FIELD(max_cat_threshold)
|
||||
.set_default(64)
|
||||
.set_lower_bound(1)
|
||||
.describe(
|
||||
"Maximum number of categories considered for split. Used only by partition-based"
|
||||
"splits.");
|
||||
DMLC_DECLARE_FIELD(min_child_weight)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(1.0f)
|
||||
|
||||
@ -74,8 +74,8 @@ class TestGPUUpdaters:
|
||||
strategies.integers(1, 2), strategies.integers(4, 7))
|
||||
@settings(deadline=None, print_blob=True)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(self, rows, cols, rounds, cats):
|
||||
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
|
||||
def test_categorical_ohe(self, rows, cols, rounds, cats):
|
||||
self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
|
||||
|
||||
@given(
|
||||
strategies.integers(10, 400),
|
||||
@ -96,7 +96,7 @@ class TestGPUUpdaters:
|
||||
cols = 10
|
||||
cats = 32
|
||||
rounds = 4
|
||||
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
|
||||
self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_invalid_category(self):
|
||||
|
||||
@ -31,6 +31,14 @@ hist_parameter_strategy = strategies.fixed_dictionaries({
|
||||
x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))
|
||||
|
||||
|
||||
cat_parameter_strategy = strategies.fixed_dictionaries(
|
||||
{
|
||||
"max_cat_to_onehot": strategies.integers(1, 128),
|
||||
"max_cat_threshold": strategies.integers(1, 128),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def train_result(param, dmat, num_rounds):
|
||||
result = {}
|
||||
xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
|
||||
@ -253,7 +261,7 @@ class TestTreeMethod:
|
||||
# Test with partition-based split
|
||||
run(self.USE_PART)
|
||||
|
||||
def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
|
||||
def run_categorical_ohe(self, rows, cols, rounds, cats, tree_method):
|
||||
onehot, label = tm.make_categorical(rows, cols, cats, True)
|
||||
cat, _ = tm.make_categorical(rows, cols, cats, False)
|
||||
|
||||
@ -328,9 +336,55 @@ class TestTreeMethod:
|
||||
strategies.integers(1, 2), strategies.integers(4, 7))
|
||||
@settings(deadline=None, print_blob=True)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(self, rows, cols, rounds, cats):
|
||||
self.run_categorical_basic(rows, cols, rounds, cats, "approx")
|
||||
self.run_categorical_basic(rows, cols, rounds, cats, "hist")
|
||||
def test_categorical_ohe(self, rows, cols, rounds, cats):
|
||||
self.run_categorical_ohe(rows, cols, rounds, cats, "approx")
|
||||
self.run_categorical_ohe(rows, cols, rounds, cats, "hist")
|
||||
|
||||
@given(
|
||||
tm.categorical_dataset_strategy,
|
||||
exact_parameter_strategy,
|
||||
hist_parameter_strategy,
|
||||
cat_parameter_strategy,
|
||||
strategies.integers(4, 32),
|
||||
strategies.sampled_from(["hist", "approx"]),
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_categorical(
|
||||
self,
|
||||
dataset: tm.TestDataset,
|
||||
exact_parameters: Dict[str, Any],
|
||||
hist_parameters: Dict[str, Any],
|
||||
cat_parameters: Dict[str, Any],
|
||||
n_rounds: int,
|
||||
tree_method: str,
|
||||
) -> None:
|
||||
cat_parameters.update(exact_parameters)
|
||||
cat_parameters.update(hist_parameters)
|
||||
cat_parameters["tree_method"] = tree_method
|
||||
|
||||
results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
|
||||
tm.non_increasing(results["train"]["rmse"])
|
||||
|
||||
@given(
|
||||
hist_parameter_strategy,
|
||||
cat_parameter_strategy,
|
||||
strategies.sampled_from(["hist", "approx"]),
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_categorical_ames_housing(
|
||||
self,
|
||||
hist_parameters: Dict[str, Any],
|
||||
cat_parameters: Dict[str, Any],
|
||||
tree_method: str,
|
||||
) -> None:
|
||||
cat_parameters.update(hist_parameters)
|
||||
dataset = tm.TestDataset(
|
||||
"ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
|
||||
)
|
||||
cat_parameters["tree_method"] = tree_method
|
||||
results = train_result(cat_parameters, dataset.get_dmat(), 16)
|
||||
tm.non_increasing(results["train"]["rmse"])
|
||||
|
||||
@given(
|
||||
strategies.integers(10, 400),
|
||||
|
||||
@ -214,7 +214,9 @@ class TestDataset:
|
||||
return params_in
|
||||
|
||||
def get_dmat(self):
|
||||
return xgb.DMatrix(self.X, self.y, self.w, base_margin=self.margin)
|
||||
return xgb.DMatrix(
|
||||
self.X, self.y, self.w, base_margin=self.margin, enable_categorical=True
|
||||
)
|
||||
|
||||
def get_device_dmat(self):
|
||||
w = None if self.w is None else cp.array(self.w)
|
||||
@ -277,6 +279,48 @@ def get_sparse():
|
||||
return X, y
|
||||
|
||||
|
||||
@memory.cache
|
||||
def get_ames_housing():
|
||||
"""
|
||||
Number of samples: 1460
|
||||
Number of features: 20
|
||||
Number of categorical features: 10
|
||||
Number of numerical features: 10
|
||||
"""
|
||||
from sklearn.datasets import fetch_openml
|
||||
X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True)
|
||||
|
||||
categorical_columns_subset: list[str] = [
|
||||
"BldgType", # 5 cats, no nan
|
||||
"GarageFinish", # 3 cats, nan
|
||||
"LotConfig", # 5 cats, no nan
|
||||
"Functional", # 7 cats, no nan
|
||||
"MasVnrType", # 4 cats, nan
|
||||
"HouseStyle", # 8 cats, no nan
|
||||
"FireplaceQu", # 5 cats, nan
|
||||
"ExterCond", # 5 cats, no nan
|
||||
"ExterQual", # 4 cats, no nan
|
||||
"PoolQC", # 3 cats, nan
|
||||
]
|
||||
|
||||
numerical_columns_subset: list[str] = [
|
||||
"3SsnPorch",
|
||||
"Fireplaces",
|
||||
"BsmtHalfBath",
|
||||
"HalfBath",
|
||||
"GarageCars",
|
||||
"TotRmsAbvGrd",
|
||||
"BsmtFinSF1",
|
||||
"BsmtFinSF2",
|
||||
"GrLivArea",
|
||||
"ScreenPorch",
|
||||
]
|
||||
|
||||
X = X[categorical_columns_subset + numerical_columns_subset]
|
||||
X[categorical_columns_subset] = X[categorical_columns_subset].astype("category")
|
||||
return X, y
|
||||
|
||||
|
||||
@memory.cache
|
||||
def get_mq2008(dpath):
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
@ -329,7 +373,6 @@ def make_categorical(
|
||||
for i in range(n_features):
|
||||
index = rng.randint(low=0, high=n_samples-1, size=int(n_samples * sparsity))
|
||||
df.iloc[index, i] = np.NaN
|
||||
assert df.iloc[:, i].isnull().values.any()
|
||||
assert n_categories == np.unique(df.dtypes[i].categories).size
|
||||
|
||||
if onehot:
|
||||
@ -337,6 +380,41 @@ def make_categorical(
|
||||
return df, label
|
||||
|
||||
|
||||
def _cat_sampled_from():
|
||||
@strategies.composite
|
||||
def _make_cat(draw):
|
||||
n_samples = draw(strategies.integers(2, 512))
|
||||
n_features = draw(strategies.integers(1, 4))
|
||||
n_cats = draw(strategies.integers(1, 128))
|
||||
sparsity = draw(
|
||||
strategies.floats(
|
||||
min_value=0,
|
||||
max_value=1,
|
||||
allow_nan=False,
|
||||
allow_infinity=False,
|
||||
allow_subnormal=False,
|
||||
)
|
||||
)
|
||||
return n_samples, n_features, n_cats, sparsity
|
||||
|
||||
def _build(args):
|
||||
n_samples = args[0]
|
||||
n_features = args[1]
|
||||
n_cats = args[2]
|
||||
sparsity = args[3]
|
||||
return TestDataset(
|
||||
f"{n_samples}x{n_features}-{n_cats}-{sparsity}",
|
||||
lambda: make_categorical(n_samples, n_features, n_cats, False, sparsity),
|
||||
"reg:squarederror",
|
||||
"rmse",
|
||||
)
|
||||
|
||||
return _make_cat().map(_build)
|
||||
|
||||
|
||||
categorical_dataset_strategy = _cat_sampled_from()
|
||||
|
||||
|
||||
@memory.cache
|
||||
def make_sparse_regression(
|
||||
n_samples: int, n_features: int, sparsity: float, as_dense: bool
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user