Fix cache with gc (#8851)

- Make DMatrixCache thread-safe.
- Remove the use of thread-local memory.
This commit is contained in:
Jiaming Yuan
2023-03-01 00:39:06 +08:00
committed by GitHub
parent d9688f93c7
commit d54ef56f6f
4 changed files with 143 additions and 65 deletions

View File

@@ -328,9 +328,6 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
using LearnerAPIThreadLocalStore =
dmlc::ThreadLocalStore<std::map<Learner const *, XGBAPIThreadLocalEntry>>;
using ThreadLocalPredictionCache =
dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
namespace {
StringView ModelMsg() {
return StringView{
@@ -368,6 +365,8 @@ class LearnerConfiguration : public Learner {
LearnerModelParam learner_model_param_;
LearnerTrainParam tparam_;
// Initial prediction.
PredictionContainer prediction_container_;
std::vector<std::string> metric_names_;
void ConfigureModelParamWithoutBaseScore() {
@@ -426,22 +425,15 @@ class LearnerConfiguration : public Learner {
}
public:
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix>> cache)
: need_configuration_{true} {
monitor_.Init("Learner");
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
for (std::shared_ptr<DMatrix> const& d : cache) {
if (d) {
local_cache.Cache(d, Context::kCpuId);
prediction_container_.Cache(d, Context::kCpuId);
}
}
}
~LearnerConfiguration() override {
auto local_cache = ThreadLocalPredictionCache::Get();
if (local_cache->find(this) != local_cache->cend()) {
local_cache->erase(this);
}
}
// Configuration before data is known.
void Configure() override {
@@ -499,10 +491,6 @@ class LearnerConfiguration : public Learner {
CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0) << ModelNotFitted();
}
virtual PredictionContainer* GetPredictionCache() const {
return &((*ThreadLocalPredictionCache::Get())[this]);
}
void LoadConfig(Json const& in) override {
// If configuration is loaded, ensure that the model came from the same version
CHECK(IsA<Object>(in));
@@ -741,11 +729,10 @@ class LearnerConfiguration : public Learner {
if (mparam_.num_feature == 0) {
// TODO(hcho3): Change num_feature to 64-bit integer
unsigned num_feature = 0;
auto local_cache = this->GetPredictionCache();
for (auto& matrix : local_cache->Container()) {
CHECK(matrix.first);
for (auto const& matrix : prediction_container_.Container()) {
CHECK(matrix.first.ptr);
CHECK(!matrix.second.ref.expired());
const uint64_t num_col = matrix.first->Info().num_col_;
const uint64_t num_col = matrix.first.ptr->Info().num_col_;
CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
<< "Unfortunately, XGBoost does not support data matrices with "
<< std::numeric_limits<unsigned>::max() << " features or greater";
@@ -817,13 +804,13 @@ class LearnerConfiguration : public Learner {
*/
void ConfigureTargets() {
CHECK(this->obj_);
auto const& cache = this->GetPredictionCache()->Container();
auto const& cache = prediction_container_.Container();
size_t n_targets = 1;
for (auto const& d : cache) {
if (n_targets == 1) {
n_targets = this->obj_->Targets(d.first->Info());
n_targets = this->obj_->Targets(d.first.ptr->Info());
} else {
auto t = this->obj_->Targets(d.first->Info());
auto t = this->obj_->Targets(d.first.ptr->Info());
CHECK(n_targets == t || 1 == t) << "Inconsistent labels.";
}
}
@@ -1275,8 +1262,7 @@ class LearnerImpl : public LearnerIO {
this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache();
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
auto& predt = prediction_container_.Cache(train, ctx_.gpu_id);
monitor_.Start("PredictRaw");
this->PredictRaw(train.get(), &predt, true, 0, 0);
@@ -1303,8 +1289,7 @@ class LearnerImpl : public LearnerIO {
this->ValidateDMatrix(train.get(), true);
auto local_cache = this->GetPredictionCache();
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
auto& predt = prediction_container_.Cache(train, ctx_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
monitor_.Stop("BoostOneIter");
}
@@ -1326,10 +1311,9 @@ class LearnerImpl : public LearnerIO {
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
}
auto local_cache = this->GetPredictionCache();
for (size_t i = 0; i < data_sets.size(); ++i) {
std::shared_ptr<DMatrix> m = data_sets[i];
auto &predt = local_cache->Cache(m, ctx_.gpu_id);
auto &predt = prediction_container_.Cache(m, ctx_.gpu_id);
this->ValidateDMatrix(m.get(), false);
this->PredictRaw(m.get(), &predt, false, 0, 0);
@@ -1370,8 +1354,7 @@ class LearnerImpl : public LearnerIO {
} else if (pred_leaf) {
gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
} else {
auto local_cache = this->GetPredictionCache();
auto& prediction = local_cache->Cache(data, ctx_.gpu_id);
auto& prediction = prediction_container_.Cache(data, ctx_.gpu_id);
this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
// Copy the prediction cache to output prediction. out_preds comes from C API
out_preds->SetDevice(ctx_.gpu_id);