Ensure models with categorical splits don't use old binary format. (#7666)

This commit is contained in:
Jiaming Yuan 2022-02-19 08:05:28 +08:00 committed by GitHub
parent 14d61b0141
commit 7366d3b20c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 1 deletions

View File

@ -868,6 +868,7 @@ void RegTree::Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
CHECK_EQ(param.deprecated_num_roots, 1);
CHECK_NE(param.num_nodes, 0);
CHECK(!HasCategoricalSplit()) << "Please JSON/UBJSON for saving models with categorical splits.";
if (DMLC_IO_NO_ENDIAN_SWAP) {
fo->Write(&param, sizeof(TreeParam));

View File

@ -70,7 +70,7 @@ class TestGPUUpdaters:
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
@pytest.mark.skipif(**tm.no_cupy())
def test_invalid_categorical(self):
# Delegates to the shared CPU test helper, asking it to run the invalid-category check with the "gpu_hist" tree method.
def test_invalid_category(self):
self.cputest.run_invalid_category("gpu_hist")
@pytest.mark.skipif(**tm.no_cupy())

View File

@ -381,6 +381,29 @@ class TestModels:
'objective': 'multi:softmax'}
validate_model(parameters)
def test_categorical_model_io(self):
    """Check model saving and loading for a booster trained on categorical data.

    Saving to a ``model.binary`` path must raise ``ValueError`` with a
    message mentioning ``JSON/UBJSON``.  Saving to ``model.json`` and then
    ``model.ubj`` must each reload into a booster whose predictions match
    the predictions taken right after training.
    """
    features, labels = tm.make_categorical(256, 16, 71, False)
    Xy = xgb.DMatrix(features, labels, enable_categorical=True)
    booster = xgb.train({"tree_method": "approx"}, Xy, num_boost_round=16)
    expected = booster.predict(Xy)

    with tempfile.TemporaryDirectory() as tempdir:
        # The save to the ``.binary`` path is expected to be rejected.
        rejected_path = os.path.join(tempdir, "model.binary")
        with pytest.raises(ValueError, match=r".*JSON/UBJSON.*"):
            booster.save_model(rejected_path)

        # Each pass saves the current booster and carries on with the
        # reloaded one, so the second file is written by the booster that
        # was loaded back from the first.
        for file_name in ("model.json", "model.ubj"):
            path = os.path.join(tempdir, file_name)
            booster.save_model(path)
            booster = xgb.Booster(model_file=path)
            np.testing.assert_allclose(expected, booster.predict(Xy))
@pytest.mark.skipif(**tm.no_sklearn())
def test_attributes(self):
from sklearn.datasets import load_iris