Add an API guard to prevent global variables being changed. (#6891)

This commit is contained in:
Jiaming Yuan
2021-04-23 10:27:57 +08:00
committed by GitHub
parent 896aede340
commit a2ecbdaa31
7 changed files with 96 additions and 24 deletions

View File

@@ -43,7 +43,7 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
}
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_BEGIN();
API_BEGIN_UNGUARD();
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
registry->Register(callback);
API_END();
@@ -568,9 +568,9 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
bst_float *grad,
bst_float *hess,
xgboost::bst_ulong len) {
HostDeviceVector<GradientPair> tmp_gpair;
API_BEGIN();
CHECK_HANDLE();
HostDeviceVector<GradientPair> tmp_gpair;
auto* bst = static_cast<Learner*>(handle);
auto* dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain);

View File

@@ -6,6 +6,24 @@
#include "c_api_utils.h"
#include "../data/device_adapter.cuh"
namespace xgboost {
void XGBoostAPIGuard::SetGPUAttribute() {
try {
device_id_ = dh::CurrentDevice();
} catch (dmlc::Error const&) {
// do nothing, running on CPU only machine
}
}
void XGBoostAPIGuard::RestoreGPUAttribute() {
try {
dh::safe_cuda(cudaSetDevice(device_id_));
} catch (dmlc::Error const&) {
// do nothing, running on CPU only machine
}
}
} // namespace xgboost
using namespace xgboost; // NOLINT
XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,

View File

@@ -9,16 +9,24 @@
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include "c_api_utils.h"
/*! \brief macro to guard beginning and end section of all functions */
#ifdef LOG_CAPI_INVOCATION
#define API_BEGIN() \
LOG(CONSOLE) << "[XGBoost C API invocation] " << __PRETTY_FUNCTION__; try {
#define API_BEGIN() \
LOG(CONSOLE) << "[XGBoost C API invocation] " << __PRETTY_FUNCTION__; \
try { \
auto __guard = ::xgboost::XGBoostAPIGuard();
#else // LOG_CAPI_INVOCATION
#define API_BEGIN() try {
#define API_BEGIN() \
try { \
auto __guard = ::xgboost::XGBoostAPIGuard();
#define API_BEGIN_UNGUARD() try {
#endif // LOG_CAPI_INVOCATION
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
and finishes with API_END() */
#define API_END() \
} catch (dmlc::Error & _except_) { \
return XGBAPIHandleException(_except_); \
@@ -29,12 +37,6 @@
#define CHECK_HANDLE() if (handle == nullptr) \
LOG(FATAL) << "DMatrix/Booster has not been intialized or has already been disposed.";
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
/*!
* \brief Set the last error message needed by C API

View File

@@ -13,8 +13,6 @@
#include "xgboost/learner.h"
#include "xgboost/c_api.h"
#include "c_api_error.h"
namespace xgboost {
/* \brief Determine the output shape of prediction.
*
@@ -158,5 +156,28 @@ inline float GetMissing(Json const &config) {
}
return missing;
}
// Safe guard some global variables from being changed by XGBoost.
class XGBoostAPIGuard {
int32_t n_threads_ {omp_get_max_threads()};
int32_t device_id_ {0};
#if defined(XGBOOST_USE_CUDA)
void SetGPUAttribute();
void RestoreGPUAttribute();
#else
void SetGPUAttribute() {}
void RestoreGPUAttribute() {}
#endif
public:
XGBoostAPIGuard() {
SetGPUAttribute();
}
~XGBoostAPIGuard() {
omp_set_num_threads(n_threads_);
RestoreGPUAttribute();
}
};
} // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_