[R] Don't cap global number of threads for serialization (#10028)
This commit is contained in:
@@ -106,30 +106,13 @@ void GBTreeModel::Load(dmlc::Stream* fi) {
|
||||
Validate(*this);
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::int32_t IOThreads(Context const* ctx) {
|
||||
CHECK(ctx);
|
||||
std::int32_t n_threads = ctx->Threads();
|
||||
// CRAN checks for number of threads used by examples, but we might not have the right
|
||||
// number of threads when serializing/unserializing models as nthread is a booster
|
||||
// parameter, which is only effective after booster initialization.
|
||||
//
|
||||
// The threshold ratio of CPU time to user time for R is 2.5, we set the number of
|
||||
// threads to 2.
|
||||
#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
|
||||
n_threads = std::min(2, n_threads);
|
||||
#endif
|
||||
return n_threads;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void GBTreeModel::SaveModel(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
|
||||
out["gbtree_model_param"] = ToJson(param);
|
||||
std::vector<Json> trees_json(trees.size());
|
||||
|
||||
common::ParallelFor(trees.size(), IOThreads(ctx_), [&](auto t) {
|
||||
common::ParallelFor(trees.size(), ctx_->Threads(), [&](auto t) {
|
||||
auto const& tree = trees[t];
|
||||
Json jtree{Object{}};
|
||||
tree->SaveModel(&jtree);
|
||||
@@ -167,7 +150,7 @@ void GBTreeModel::LoadModel(Json const& in) {
|
||||
CHECK_EQ(tree_info_json.size(), param.num_trees);
|
||||
tree_info.resize(param.num_trees);
|
||||
|
||||
common::ParallelFor(param.num_trees, IOThreads(ctx_), [&](auto t) {
|
||||
common::ParallelFor(param.num_trees, ctx_->Threads(), [&](auto t) {
|
||||
auto tree_id = get<Integer const>(trees_json[t]["id"]);
|
||||
trees.at(tree_id).reset(new RegTree{});
|
||||
trees[tree_id]->LoadModel(trees_json[t]);
|
||||
|
||||
Reference in New Issue
Block a user