Add an API guard to prevent global variables being changed. (#6891)

This commit is contained in:
Jiaming Yuan 2021-04-23 10:27:57 +08:00 committed by GitHub
parent 896aede340
commit a2ecbdaa31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 96 additions and 24 deletions

View File

@ -1939,8 +1939,11 @@ class Booster(object):
)
)
return _prediction_output(shape, dims, preds, False)
if lazy_isinstance(data, "cupy.core.core", "ndarray"):
if lazy_isinstance(data, "cupy.core.core", "ndarray") or lazy_isinstance(
data, "cupy._core.core", "ndarray"
):
from .data import _transform_cupy_array
data = _transform_cupy_array(data)
interface = data.__cuda_array_interface__
if "mask" in interface:

View File

@ -43,7 +43,7 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
}
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_BEGIN();
API_BEGIN_UNGUARD();
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
registry->Register(callback);
API_END();
@ -568,9 +568,9 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
bst_float *grad,
bst_float *hess,
xgboost::bst_ulong len) {
HostDeviceVector<GradientPair> tmp_gpair;
API_BEGIN();
CHECK_HANDLE();
HostDeviceVector<GradientPair> tmp_gpair;
auto* bst = static_cast<Learner*>(handle);
auto* dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain);

View File

@ -6,6 +6,24 @@
#include "c_api_utils.h"
#include "../data/device_adapter.cuh"
namespace xgboost {
void XGBoostAPIGuard::SetGPUAttribute() {
try {
device_id_ = dh::CurrentDevice();
} catch (dmlc::Error const&) {
// do nothing, running on CPU only machine
}
}
void XGBoostAPIGuard::RestoreGPUAttribute() {
try {
dh::safe_cuda(cudaSetDevice(device_id_));
} catch (dmlc::Error const&) {
// do nothing, running on CPU only machine
}
}
} // namespace xgboost
using namespace xgboost; // NOLINT
XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,

View File

@ -9,16 +9,24 @@
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include "c_api_utils.h"
/*! \brief macro to guard beginning and end section of all functions */
#ifdef LOG_CAPI_INVOCATION
#define API_BEGIN() \
LOG(CONSOLE) << "[XGBoost C API invocation] " << __PRETTY_FUNCTION__; try {
#define API_BEGIN() \
LOG(CONSOLE) << "[XGBoost C API invocation] " << __PRETTY_FUNCTION__; \
try { \
auto __guard = ::xgboost::XGBoostAPIGuard();
#else // LOG_CAPI_INVOCATION
#define API_BEGIN() try {
#define API_BEGIN() \
try { \
auto __guard = ::xgboost::XGBoostAPIGuard();
#define API_BEGIN_UNGUARD() try {
#endif // LOG_CAPI_INVOCATION
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
and finishes with API_END() */
#define API_END() \
} catch (dmlc::Error & _except_) { \
return XGBAPIHandleException(_except_); \
@ -29,12 +37,6 @@
#define CHECK_HANDLE() if (handle == nullptr) \
LOG(FATAL) << "DMatrix/Booster has not been intialized or has already been disposed.";
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
/*!
* \brief Set the last error message needed by C API

View File

@ -13,8 +13,6 @@
#include "xgboost/learner.h"
#include "xgboost/c_api.h"
#include "c_api_error.h"
namespace xgboost {
/* \brief Determine the output shape of prediction.
*
@ -158,5 +156,28 @@ inline float GetMissing(Json const &config) {
}
return missing;
}
// Safe guard some global variables from being changed by XGBoost.
class XGBoostAPIGuard {
int32_t n_threads_ {omp_get_max_threads()};
int32_t device_id_ {0};
#if defined(XGBOOST_USE_CUDA)
void SetGPUAttribute();
void RestoreGPUAttribute();
#else
void SetGPUAttribute() {}
void RestoreGPUAttribute() {}
#endif
public:
XGBoostAPIGuard() {
SetGPUAttribute();
}
~XGBoostAPIGuard() {
omp_set_num_threads(n_threads_);
RestoreGPUAttribute();
}
};
} // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_

View File

@ -12,7 +12,6 @@
namespace dh {
#if __CUDACC_VER_MAJOR__ > 9
constexpr std::size_t kUuidLength =
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
@ -31,16 +30,13 @@ std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
return ss.str();
}
#endif // __CUDACC_VER_MAJOR__ > 9
void AllReducer::Init(int _device_ordinal) {
#ifdef XGBOOST_USE_NCCL
LOG(DEBUG) << "Running nccl init on: " << __CUDACC_VER_MAJOR__ << "." << __CUDACC_VER_MINOR__;
device_ordinal_ = _device_ordinal;
int32_t const rank = rabit::GetRank();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
#if __CUDACC_VER_MAJOR__ > 9
int32_t const rank = rabit::GetRank();
int32_t const world = rabit::GetWorldSize();
std::vector<uint64_t> uuids(world * kUuidLength, 0);
@ -61,13 +57,13 @@ void AllReducer::Init(int _device_ordinal) {
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported";
#endif // __CUDACC_VER_MAJOR__ > 9
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclCommInitRank(&comm_, rabit::GetWorldSize(), id_, rank));
safe_cuda(cudaStreamCreate(&stream_));
initialised_ = true;

View File

@ -279,4 +279,36 @@ TEST(CAPI, XGBGlobalConfig) {
}
}
TEST(CAPI, GlobalVariables) {
size_t n_threads = omp_get_max_threads();
size_t constexpr kRows = 10;
bst_feature_t constexpr kCols = 2;
DMatrixHandle handle;
std::vector<float> data(kCols * kRows, 1.5);
ASSERT_EQ(XGDMatrixCreateFromMat_omp(data.data(), kRows, kCols,
std::numeric_limits<float>::quiet_NaN(),
&handle, 0),
0);
std::vector<float> labels(kRows, 2.0f);
ASSERT_EQ(XGDMatrixSetFloatInfo(handle, "label", labels.data(), labels.size()), 0);
DMatrixHandle m_handles[1];
m_handles[0] = handle;
BoosterHandle booster;
ASSERT_EQ(XGBoosterCreate(m_handles, 1, &booster), 0);
ASSERT_EQ(XGBoosterSetParam(booster, "nthread", "16"), 0);
omp_set_num_threads(1);
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, handle), 0);
ASSERT_EQ(omp_get_max_threads(), 1);
omp_set_num_threads(n_threads);
XGDMatrixFree(handle);
XGBoosterFree(booster);
}
} // namespace xgboost