Revert OMP guard. (#6987)
The guard protects the global variable from being changed by XGBoost. But this leads to a bug that the `n_threads` parameter is no longer used after the first iteration. This is due to the fact that `omp_set_num_threads` is only called once in `Learner::Configure` at the beginning of the training process. The guard is still useful for `gpu_id`, since this is called all the times in our codebase doesn't matter which iteration we are currently running.
This commit is contained in:
parent
cf06a266a8
commit
6e52aefb37
@ -163,7 +163,6 @@ inline float GetMissing(Json const &config) {
|
|||||||
|
|
||||||
// Safe guard some global variables from being changed by XGBoost.
|
// Safe guard some global variables from being changed by XGBoost.
|
||||||
class XGBoostAPIGuard {
|
class XGBoostAPIGuard {
|
||||||
int32_t n_threads_ {omp_get_max_threads()};
|
|
||||||
int32_t device_id_ {0};
|
int32_t device_id_ {0};
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
@ -179,7 +178,6 @@ class XGBoostAPIGuard {
|
|||||||
SetGPUAttribute();
|
SetGPUAttribute();
|
||||||
}
|
}
|
||||||
~XGBoostAPIGuard() {
|
~XGBoostAPIGuard() {
|
||||||
omp_set_num_threads(n_threads_);
|
|
||||||
RestoreGPUAttribute();
|
RestoreGPUAttribute();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -278,37 +278,4 @@ TEST(CAPI, XGBGlobalConfig) {
|
|||||||
ASSERT_EQ(err.find("verbosity"), std::string::npos);
|
ASSERT_EQ(err.find("verbosity"), std::string::npos);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI, GlobalVariables) {
|
|
||||||
size_t n_threads = omp_get_max_threads();
|
|
||||||
size_t constexpr kRows = 10;
|
|
||||||
bst_feature_t constexpr kCols = 2;
|
|
||||||
|
|
||||||
DMatrixHandle handle;
|
|
||||||
std::vector<float> data(kCols * kRows, 1.5);
|
|
||||||
|
|
||||||
|
|
||||||
ASSERT_EQ(XGDMatrixCreateFromMat_omp(data.data(), kRows, kCols,
|
|
||||||
std::numeric_limits<float>::quiet_NaN(),
|
|
||||||
&handle, 0),
|
|
||||||
0);
|
|
||||||
std::vector<float> labels(kRows, 2.0f);
|
|
||||||
ASSERT_EQ(XGDMatrixSetFloatInfo(handle, "label", labels.data(), labels.size()), 0);
|
|
||||||
|
|
||||||
DMatrixHandle m_handles[1];
|
|
||||||
m_handles[0] = handle;
|
|
||||||
|
|
||||||
BoosterHandle booster;
|
|
||||||
ASSERT_EQ(XGBoosterCreate(m_handles, 1, &booster), 0);
|
|
||||||
ASSERT_EQ(XGBoosterSetParam(booster, "nthread", "16"), 0);
|
|
||||||
|
|
||||||
omp_set_num_threads(1);
|
|
||||||
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, handle), 0);
|
|
||||||
ASSERT_EQ(omp_get_max_threads(), 1);
|
|
||||||
|
|
||||||
omp_set_num_threads(n_threads);
|
|
||||||
|
|
||||||
XGDMatrixFree(handle);
|
|
||||||
XGBoosterFree(booster);
|
|
||||||
}
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user