Check support status for categorical features. (#9946)

This commit is contained in:
Jiaming Yuan
2024-01-04 16:51:33 +08:00
committed by GitHub
parent db396ee340
commit c03a4d5088
7 changed files with 116 additions and 40 deletions

View File

@@ -1,3 +1,5 @@
import os
import tempfile
import weakref
from typing import Any, Callable, Dict, List
@@ -195,3 +197,39 @@ def test_data_cache() -> None:
assert called == 1
xgb.data._proxy_transform = transform
def test_cat_check() -> None:
n_batches = 3
n_features = 2
n_samples_per_batch = 16
batches = []
for i in range(n_batches):
X, y = tm.make_categorical(
n_samples=n_samples_per_batch,
n_features=n_features,
n_categories=3,
onehot=False,
)
batches.append((X, y))
X, y = list(zip(*batches))
it = tm.IteratorForTest(X, y, None, cache=None)
Xy: xgb.DMatrix = xgb.QuantileDMatrix(it, enable_categorical=True)
with pytest.raises(ValueError, match="categorical features"):
xgb.train({"tree_method": "exact"}, Xy)
Xy = xgb.DMatrix(X[0], y[0], enable_categorical=True)
with pytest.raises(ValueError, match="categorical features"):
xgb.train({"tree_method": "exact"}, Xy)
with tempfile.TemporaryDirectory() as tmpdir:
cache_path = os.path.join(tmpdir, "cache")
it = tm.IteratorForTest(X, y, None, cache=cache_path)
Xy = xgb.DMatrix(it, enable_categorical=True)
with pytest.raises(ValueError, match="categorical features"):
xgb.train({"booster": "gblinear"}, Xy)