Check support status for categorical features. (#9946)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
import tempfile
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
@@ -195,3 +197,39 @@ def test_data_cache() -> None:
|
||||
assert called == 1
|
||||
|
||||
xgb.data._proxy_transform = transform
|
||||
|
||||
|
||||
def test_cat_check() -> None:
|
||||
n_batches = 3
|
||||
n_features = 2
|
||||
n_samples_per_batch = 16
|
||||
|
||||
batches = []
|
||||
|
||||
for i in range(n_batches):
|
||||
X, y = tm.make_categorical(
|
||||
n_samples=n_samples_per_batch,
|
||||
n_features=n_features,
|
||||
n_categories=3,
|
||||
onehot=False,
|
||||
)
|
||||
batches.append((X, y))
|
||||
|
||||
X, y = list(zip(*batches))
|
||||
it = tm.IteratorForTest(X, y, None, cache=None)
|
||||
Xy: xgb.DMatrix = xgb.QuantileDMatrix(it, enable_categorical=True)
|
||||
|
||||
with pytest.raises(ValueError, match="categorical features"):
|
||||
xgb.train({"tree_method": "exact"}, Xy)
|
||||
|
||||
Xy = xgb.DMatrix(X[0], y[0], enable_categorical=True)
|
||||
with pytest.raises(ValueError, match="categorical features"):
|
||||
xgb.train({"tree_method": "exact"}, Xy)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cache_path = os.path.join(tmpdir, "cache")
|
||||
|
||||
it = tm.IteratorForTest(X, y, None, cache=cache_path)
|
||||
Xy = xgb.DMatrix(it, enable_categorical=True)
|
||||
with pytest.raises(ValueError, match="categorical features"):
|
||||
xgb.train({"booster": "gblinear"}, Xy)
|
||||
|
||||
Reference in New Issue
Block a user