Ensure models with categorical splits don't use old binary format. (#7666)
This commit is contained in:
parent
14d61b0141
commit
7366d3b20c
@ -868,6 +868,7 @@ void RegTree::Save(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
|
||||
CHECK_EQ(param.deprecated_num_roots, 1);
|
||||
CHECK_NE(param.num_nodes, 0);
|
||||
CHECK(!HasCategoricalSplit()) << "Please JSON/UBJSON for saving models with categorical splits.";
|
||||
|
||||
if (DMLC_IO_NO_ENDIAN_SWAP) {
|
||||
fo->Write(¶m, sizeof(TreeParam));
|
||||
|
||||
@ -70,7 +70,7 @@ class TestGPUUpdaters:
|
||||
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_invalid_categorical(self):
|
||||
def test_invalid_category(self):
|
||||
self.cputest.run_invalid_category("gpu_hist")
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
|
||||
@ -381,6 +381,29 @@ class TestModels:
|
||||
'objective': 'multi:softmax'}
|
||||
validate_model(parameters)
|
||||
|
||||
def test_categorical_model_io(self):
|
||||
X, y = tm.make_categorical(256, 16, 71, False)
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
booster = xgb.train({"tree_method": "approx"}, Xy, num_boost_round=16)
|
||||
predt_0 = booster.predict(Xy)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
path = os.path.join(tempdir, "model.binary")
|
||||
with pytest.raises(ValueError, match=r".*JSON/UBJSON.*"):
|
||||
booster.save_model(path)
|
||||
|
||||
path = os.path.join(tempdir, "model.json")
|
||||
booster.save_model(path)
|
||||
booster = xgb.Booster(model_file=path)
|
||||
predt_1 = booster.predict(Xy)
|
||||
np.testing.assert_allclose(predt_0, predt_1)
|
||||
|
||||
path = os.path.join(tempdir, "model.ubj")
|
||||
booster.save_model(path)
|
||||
booster = xgb.Booster(model_file=path)
|
||||
predt_1 = booster.predict(Xy)
|
||||
np.testing.assert_allclose(predt_0, predt_1)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_attributes(self):
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user