Clear all cache after model load. (#8904)

Author: Jiaming Yuan, 2023-03-14 22:09:36 +08:00 (committed by GitHub)
parent c400fa1e8d
commit 910ce580c8
3 changed files with 37 additions and 0 deletions


@@ -116,6 +116,18 @@ class DMatrixCache {
   * \param cache_size Maximum size of the cache.
   */
  explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}

  DMatrixCache& operator=(DMatrixCache&& that) {
    CHECK(lock_.try_lock());
    lock_.unlock();
    CHECK(that.lock_.try_lock());
    that.lock_.unlock();

    std::swap(this->container_, that.container_);
    std::swap(this->queue_, that.queue_);
    std::swap(this->max_size_, that.max_size_);
    return *this;
  }

  /**
   * \brief Cache a new DMatrix if it's not in the cache already.
   *

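Context for the hunk above: the move-assignment operator checks, via try_lock, that neither cache is currently in use and then swaps the guarded members, so a cache can be wiped simply by assigning a freshly constructed instance over it. That is presumably why it lands in the same commit as ClearCaches() in the next file, which resets the prediction container exactly that way. What follows is a minimal, self-contained sketch of the pattern, not the real DMatrixCache: the Cache class, its FIFO eviction, and the CHECK macro stand-in are invented for illustration.

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>

// Stand-in for the CHECK() macro used in the real code: always evaluated,
// aborts with a message when the condition does not hold.
#define CHECK(cond)                                      \
  do {                                                   \
    if (!(cond)) {                                       \
      std::cerr << "Check failed: " << #cond << "\n";    \
      std::abort();                                      \
    }                                                    \
  } while (false)

// Minimal mutex-guarded cache with FIFO eviction.  The move assignment
// mirrors the diff: verify neither object is locked, then swap the
// protected members (the mutex itself is never moved).
class Cache {
 public:
  explicit Cache(std::size_t cache_size) : max_size_{cache_size} {}

  Cache& operator=(Cache&& that) {
    // A held lock means another thread is mid-operation; swapping state out
    // from under it would be a race, so fail loudly instead.
    CHECK(lock_.try_lock());
    lock_.unlock();
    CHECK(that.lock_.try_lock());
    that.lock_.unlock();

    std::swap(this->container_, that.container_);
    std::swap(this->queue_, that.queue_);
    std::swap(this->max_size_, that.max_size_);
    return *this;
  }

  void Put(std::string key, int value) {
    std::lock_guard<std::mutex> guard{lock_};
    container_[key] = value;
    queue_.push(std::move(key));
    if (queue_.size() > max_size_) {
      container_.erase(queue_.front());
      queue_.pop();
    }
  }

  std::size_t Size() const { return container_.size(); }

 private:
  std::mutex lock_;
  std::unordered_map<std::string, int> container_;
  std::queue<std::string> queue_;
  std::size_t max_size_;
};

int main() {
  Cache cache{4};
  cache.Put("a", 1);
  cache.Put("b", 2);

  // Reset the cache the same way ClearCaches() resets the prediction
  // container: move-assign an empty instance over it.
  cache = Cache{4};
  std::cout << "entries after reset: " << cache.Size() << "\n";  // prints 0
  return 0;
}

Resetting by assignment rather than through a dedicated Clear() method keeps the post-reset state identical to a newly constructed cache, which is what the one-line ClearCaches() in the next file relies on.
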

@@ -868,6 +868,8 @@ class LearnerIO : public LearnerConfiguration {
  // Will be removed once JSON takes over. Right now we still load some RDS files from R.
  std::string const serialisation_header_ { u8"CONFIG-offset:" };

  void ClearCaches() { this->prediction_container_ = PredictionContainer{}; }

 public:
  explicit LearnerIO(std::vector<std::shared_ptr<DMatrix>> cache) : LearnerConfiguration{cache} {}
@@ -920,6 +922,7 @@ class LearnerIO : public LearnerConfiguration {
    }
    this->need_configuration_ = true;
    this->ClearCaches();
  }

  void SaveModel(Json* p_out) const override {
@@ -1096,6 +1099,7 @@ class LearnerIO : public LearnerConfiguration {
    cfg_.insert(n.cbegin(), n.cend());
    this->need_configuration_ = true;
    this->ClearCaches();
  }
  // Save model into binary format. The code is about to be deprecated by more robust

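The ClearCaches() calls are the substance of the fix: both load paths touched here now drop the predictions cached for the previous model, in addition to flagging the learner for reconfiguration, so a stale cached prediction can no longer be returned after a model is loaded. The sketch below is a self-contained caricature of the failure mode and is not xgboost API; ToyLearner, Model, and the pointer-keyed PredictionCache are invented for illustration. Without the ClearCaches() call in LoadModel, the second Predict would return the value cached for the old model.

#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

// A "model" reduced to one coefficient.
struct Model {
  double weight{1.0};
};

// A learner that memoises predictions per dataset, keyed by the dataset's
// address (a stand-in for caching predictions per DMatrix).
class ToyLearner {
 public:
  using PredictionCache = std::unordered_map<void const*, std::vector<double>>;

  explicit ToyLearner(Model model) : model_{model} {}

  std::vector<double> const& Predict(std::vector<double> const& data) {
    auto it = cache_.find(&data);
    if (it != cache_.end()) {
      return it->second;  // served from the cache, model_ is not consulted
    }
    std::vector<double> out;
    out.reserve(data.size());
    for (double v : data) {
      out.push_back(v * model_.weight);
    }
    return cache_.emplace(&data, std::move(out)).first->second;
  }

  void LoadModel(Model model) {
    model_ = model;
    // The fix: entries cached for the previous model are now stale.
    ClearCaches();
  }

 private:
  void ClearCaches() { cache_ = PredictionCache{}; }

  Model model_;
  PredictionCache cache_;
};

int main() {
  std::vector<double> const data{1.0, 2.0, 3.0};

  ToyLearner learner{Model{2.0}};
  std::cout << learner.Predict(data)[0] << "\n";  // 2, now cached

  // Load a different model; if LoadModel skipped ClearCaches(), the next
  // call would still return the prediction cached for weight == 2.
  learner.LoadModel(Model{10.0});
  std::cout << learner.Predict(data)[0] << "\n";  // 10
  return 0;
}

The Python test added in the last file exercises the same scenario end to end through the public save_model / load_model / predict interface.
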

@@ -234,6 +234,27 @@ class TestModels:
        xgb.cv(param, dtrain, num_round, nfold=5,
               metrics={'error'}, seed=0, show_stdv=False)

    def test_prediction_cache(self) -> None:
        X, y = tm.make_sparse_regression(512, 4, 0.5, as_dense=False)
        Xy = xgb.DMatrix(X, y)
        param = {"max_depth": 8}
        booster = xgb.train(param, Xy, num_boost_round=1)

        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "model.json")
            booster.save_model(path)
            predt_0 = booster.predict(Xy)

            # Train a different model so its predictions diverge from the saved one.
            param["max_depth"] = 2
            booster = xgb.train(param, Xy, num_boost_round=1)
            predt_1 = booster.predict(Xy)
            assert not np.isclose(predt_0, predt_1).all()

            # Loading the saved model must invalidate the prediction cache.
            booster.load_model(path)
            predt_2 = booster.predict(Xy)
            np.testing.assert_allclose(predt_0, predt_2)

    def test_feature_names_validation(self):
        X = np.random.random((10, 3))
        y = np.random.randint(2, size=(10,))