Add an API guard to prevent global variables being changed. (#6891)
This commit is contained in:
@@ -43,7 +43,7 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
|
||||
}
|
||||
|
||||
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
|
||||
API_BEGIN();
|
||||
API_BEGIN_UNGUARD();
|
||||
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
|
||||
registry->Register(callback);
|
||||
API_END();
|
||||
@@ -568,9 +568,9 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
||||
bst_float *grad,
|
||||
bst_float *hess,
|
||||
xgboost::bst_ulong len) {
|
||||
HostDeviceVector<GradientPair> tmp_gpair;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
HostDeviceVector<GradientPair> tmp_gpair;
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
auto* dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
|
||||
@@ -6,6 +6,24 @@
|
||||
#include "c_api_utils.h"
|
||||
#include "../data/device_adapter.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
void XGBoostAPIGuard::SetGPUAttribute() {
|
||||
try {
|
||||
device_id_ = dh::CurrentDevice();
|
||||
} catch (dmlc::Error const&) {
|
||||
// do nothing, running on CPU only machine
|
||||
}
|
||||
}
|
||||
|
||||
void XGBoostAPIGuard::RestoreGPUAttribute() {
|
||||
try {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
} catch (dmlc::Error const&) {
|
||||
// do nothing, running on CPU only machine
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
|
||||
|
||||
@@ -9,16 +9,24 @@
|
||||
#include <dmlc/base.h>
|
||||
#include <dmlc/logging.h>
|
||||
|
||||
#include "c_api_utils.h"
|
||||
|
||||
/*! \brief macro to guard beginning and end section of all functions */
|
||||
#ifdef LOG_CAPI_INVOCATION
|
||||
#define API_BEGIN() \
|
||||
LOG(CONSOLE) << "[XGBoost C API invocation] " << __PRETTY_FUNCTION__; try {
|
||||
#define API_BEGIN() \
|
||||
LOG(CONSOLE) << "[XGBoost C API invocation] " << __PRETTY_FUNCTION__; \
|
||||
try { \
|
||||
auto __guard = ::xgboost::XGBoostAPIGuard();
|
||||
#else // LOG_CAPI_INVOCATION
|
||||
#define API_BEGIN() try {
|
||||
#define API_BEGIN() \
|
||||
try { \
|
||||
auto __guard = ::xgboost::XGBoostAPIGuard();
|
||||
|
||||
#define API_BEGIN_UNGUARD() try {
|
||||
#endif // LOG_CAPI_INVOCATION
|
||||
|
||||
/*! \brief every function starts with API_BEGIN();
|
||||
and finishes with API_END() or API_END_HANDLE_ERROR */
|
||||
and finishes with API_END() */
|
||||
#define API_END() \
|
||||
} catch (dmlc::Error & _except_) { \
|
||||
return XGBAPIHandleException(_except_); \
|
||||
@@ -29,12 +37,6 @@
|
||||
|
||||
#define CHECK_HANDLE() if (handle == nullptr) \
|
||||
LOG(FATAL) << "DMatrix/Booster has not been intialized or has already been disposed.";
|
||||
/*!
|
||||
* \brief every function starts with API_BEGIN();
|
||||
* and finishes with API_END() or API_END_HANDLE_ERROR
|
||||
* The finally clause contains procedure to cleanup states when an error happens.
|
||||
*/
|
||||
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
|
||||
|
||||
/*!
|
||||
* \brief Set the last error message needed by C API
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
#include "c_api_error.h"
|
||||
|
||||
namespace xgboost {
|
||||
/* \brief Determine the output shape of prediction.
|
||||
*
|
||||
@@ -158,5 +156,28 @@ inline float GetMissing(Json const &config) {
|
||||
}
|
||||
return missing;
|
||||
}
|
||||
|
||||
// Safe guard some global variables from being changed by XGBoost.
|
||||
class XGBoostAPIGuard {
|
||||
int32_t n_threads_ {omp_get_max_threads()};
|
||||
int32_t device_id_ {0};
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
void SetGPUAttribute();
|
||||
void RestoreGPUAttribute();
|
||||
#else
|
||||
void SetGPUAttribute() {}
|
||||
void RestoreGPUAttribute() {}
|
||||
#endif
|
||||
|
||||
public:
|
||||
XGBoostAPIGuard() {
|
||||
SetGPUAttribute();
|
||||
}
|
||||
~XGBoostAPIGuard() {
|
||||
omp_set_num_threads(n_threads_);
|
||||
RestoreGPUAttribute();
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
Reference in New Issue
Block a user