From 910ce580c893dad10b6c041ddb7ec2372fad800f Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 14 Mar 2023 22:09:36 +0800
Subject: [PATCH] Clear all cache after model load. (#8904)

---
 include/xgboost/cache.h           | 12 ++++++++++++
 src/learner.cc                    |  4 ++++
 tests/python/test_basic_models.py | 21 +++++++++++++++++++++
 3 files changed, 37 insertions(+)

diff --git a/include/xgboost/cache.h b/include/xgboost/cache.h
index 781f45b1c..6195e730c 100644
--- a/include/xgboost/cache.h
+++ b/include/xgboost/cache.h
@@ -116,6 +116,18 @@ class DMatrixCache {
    * \param cache_size Maximum size of the cache.
    */
   explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
+
+  DMatrixCache& operator=(DMatrixCache&& that) {
+    CHECK(lock_.try_lock());
+    lock_.unlock();
+    CHECK(that.lock_.try_lock());
+    that.lock_.unlock();
+    std::swap(this->container_, that.container_);
+    std::swap(this->queue_, that.queue_);
+    std::swap(this->max_size_, that.max_size_);
+    return *this;
+  }
+
   /**
    * \brief Cache a new DMatrix if it's not in the cache already.
    *
diff --git a/src/learner.cc b/src/learner.cc
index 62875ead6..d91add70d 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -868,6 +868,8 @@ class LearnerIO : public LearnerConfiguration {
   // Will be removed once JSON takes over. Right now we still loads some RDS files from R.
   std::string const serialisation_header_ { u8"CONFIG-offset:" };
 
+  void ClearCaches() { this->prediction_container_ = PredictionContainer{}; }
+
  public:
   explicit LearnerIO(std::vector<std::shared_ptr<DMatrix>> cache) : LearnerConfiguration{cache} {}
 
@@ -920,6 +922,7 @@ class LearnerIO : public LearnerConfiguration {
     }
 
     this->need_configuration_ = true;
+    this->ClearCaches();
   }
 
   void SaveModel(Json* p_out) const override {
@@ -1096,6 +1099,7 @@ class LearnerIO : public LearnerConfiguration {
     cfg_.insert(n.cbegin(), n.cend());
 
     this->need_configuration_ = true;
+    this->ClearCaches();
   }
 
   // Save model into binary format. The code is about to be deprecated by more robust
diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py
index 06f666da1..acacc55f8 100644
--- a/tests/python/test_basic_models.py
+++ b/tests/python/test_basic_models.py
@@ -234,6 +234,27 @@ class TestModels:
         xgb.cv(param, dtrain, num_round, nfold=5,
                metrics={'error'}, seed=0, show_stdv=False)
 
+    def test_prediction_cache(self) -> None:
+        X, y = tm.make_sparse_regression(512, 4, 0.5, as_dense=False)
+        Xy = xgb.DMatrix(X, y)
+        param = {"max_depth": 8}
+        booster = xgb.train(param, Xy, num_boost_round=1)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = os.path.join(tmpdir, "model.json")
+            booster.save_model(path)
+
+            predt_0 = booster.predict(Xy)
+
+            param["max_depth"] = 2
+
+            booster = xgb.train(param, Xy, num_boost_round=1)
+            predt_1 = booster.predict(Xy)
+            assert not np.isclose(predt_0, predt_1).all()
+
+            booster.load_model(path)
+            predt_2 = booster.predict(Xy)
+            np.testing.assert_allclose(predt_0, predt_2)
+
     def test_feature_names_validation(self):
         X = np.random.random((10, 3))
         y = np.random.randint(2, size=(10,))