Add an API guard to prevent global variables being changed. (#6891)
This commit is contained in:
parent
896aede340
commit
a2ecbdaa31
@ -1939,8 +1939,11 @@ class Booster(object):
|
||||
)
|
||||
)
|
||||
return _prediction_output(shape, dims, preds, False)
|
||||
if lazy_isinstance(data, "cupy.core.core", "ndarray"):
|
||||
if lazy_isinstance(data, "cupy.core.core", "ndarray") or lazy_isinstance(
|
||||
data, "cupy._core.core", "ndarray"
|
||||
):
|
||||
from .data import _transform_cupy_array
|
||||
|
||||
data = _transform_cupy_array(data)
|
||||
interface = data.__cuda_array_interface__
|
||||
if "mask" in interface:
|
||||
|
||||
@ -43,7 +43,7 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
|
||||
}
|
||||
|
||||
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
|
||||
API_BEGIN();
|
||||
API_BEGIN_UNGUARD();
|
||||
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
|
||||
registry->Register(callback);
|
||||
API_END();
|
||||
@ -568,9 +568,9 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
||||
bst_float *grad,
|
||||
bst_float *hess,
|
||||
xgboost::bst_ulong len) {
|
||||
HostDeviceVector<GradientPair> tmp_gpair;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
HostDeviceVector<GradientPair> tmp_gpair;
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
auto* dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
|
||||
@ -6,6 +6,24 @@
|
||||
#include "c_api_utils.h"
|
||||
#include "../data/device_adapter.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
void XGBoostAPIGuard::SetGPUAttribute() {
|
||||
try {
|
||||
device_id_ = dh::CurrentDevice();
|
||||
} catch (dmlc::Error const&) {
|
||||
// do nothing, running on CPU only machine
|
||||
}
|
||||
}
|
||||
|
||||
void XGBoostAPIGuard::RestoreGPUAttribute() {
|
||||
try {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
} catch (dmlc::Error const&) {
|
||||
// do nothing, running on CPU only machine
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
|
||||
|
||||
@ -9,16 +9,24 @@
|
||||
#include <dmlc/base.h>
|
||||
#include <dmlc/logging.h>
|
||||
|
||||
#include "c_api_utils.h"
|
||||
|
||||
/*! \brief macro to guard beginning and end section of all functions */
|
||||
#ifdef LOG_CAPI_INVOCATION
|
||||
#define API_BEGIN() \
|
||||
LOG(CONSOLE) << "[XGBoost C API invocation] " << __PRETTY_FUNCTION__; try {
|
||||
#define API_BEGIN() \
|
||||
LOG(CONSOLE) << "[XGBoost C API invocation] " << __PRETTY_FUNCTION__; \
|
||||
try { \
|
||||
auto __guard = ::xgboost::XGBoostAPIGuard();
|
||||
#else // LOG_CAPI_INVOCATION
|
||||
#define API_BEGIN() try {
|
||||
#define API_BEGIN() \
|
||||
try { \
|
||||
auto __guard = ::xgboost::XGBoostAPIGuard();
|
||||
|
||||
#define API_BEGIN_UNGUARD() try {
|
||||
#endif // LOG_CAPI_INVOCATION
|
||||
|
||||
/*! \brief every function starts with API_BEGIN();
|
||||
and finishes with API_END() or API_END_HANDLE_ERROR */
|
||||
and finishes with API_END() */
|
||||
#define API_END() \
|
||||
} catch (dmlc::Error & _except_) { \
|
||||
return XGBAPIHandleException(_except_); \
|
||||
@ -29,12 +37,6 @@
|
||||
|
||||
#define CHECK_HANDLE() if (handle == nullptr) \
|
||||
LOG(FATAL) << "DMatrix/Booster has not been intialized or has already been disposed.";
|
||||
/*!
|
||||
* \brief every function starts with API_BEGIN();
|
||||
* and finishes with API_END() or API_END_HANDLE_ERROR
|
||||
* The finally clause contains procedure to cleanup states when an error happens.
|
||||
*/
|
||||
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
|
||||
|
||||
/*!
|
||||
* \brief Set the last error message needed by C API
|
||||
|
||||
@ -13,8 +13,6 @@
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
#include "c_api_error.h"
|
||||
|
||||
namespace xgboost {
|
||||
/* \brief Determine the output shape of prediction.
|
||||
*
|
||||
@ -158,5 +156,28 @@ inline float GetMissing(Json const &config) {
|
||||
}
|
||||
return missing;
|
||||
}
|
||||
|
||||
// Safe guard some global variables from being changed by XGBoost.
|
||||
class XGBoostAPIGuard {
|
||||
int32_t n_threads_ {omp_get_max_threads()};
|
||||
int32_t device_id_ {0};
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
void SetGPUAttribute();
|
||||
void RestoreGPUAttribute();
|
||||
#else
|
||||
void SetGPUAttribute() {}
|
||||
void RestoreGPUAttribute() {}
|
||||
#endif
|
||||
|
||||
public:
|
||||
XGBoostAPIGuard() {
|
||||
SetGPUAttribute();
|
||||
}
|
||||
~XGBoostAPIGuard() {
|
||||
omp_set_num_threads(n_threads_);
|
||||
RestoreGPUAttribute();
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
|
||||
namespace dh {
|
||||
|
||||
#if __CUDACC_VER_MAJOR__ > 9
|
||||
constexpr std::size_t kUuidLength =
|
||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
||||
|
||||
@ -31,16 +30,13 @@ std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
#endif // __CUDACC_VER_MAJOR__ > 9
|
||||
|
||||
void AllReducer::Init(int _device_ordinal) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
LOG(DEBUG) << "Running nccl init on: " << __CUDACC_VER_MAJOR__ << "." << __CUDACC_VER_MINOR__;
|
||||
|
||||
device_ordinal_ = _device_ordinal;
|
||||
int32_t const rank = rabit::GetRank();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
|
||||
#if __CUDACC_VER_MAJOR__ > 9
|
||||
int32_t const rank = rabit::GetRank();
|
||||
int32_t const world = rabit::GetWorldSize();
|
||||
|
||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||
@ -61,13 +57,13 @@ void AllReducer::Init(int _device_ordinal) {
|
||||
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, world)
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported";
|
||||
#endif // __CUDACC_VER_MAJOR__ > 9
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
|
||||
id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclCommInitRank(&comm_, rabit::GetWorldSize(), id_, rank));
|
||||
safe_cuda(cudaStreamCreate(&stream_));
|
||||
initialised_ = true;
|
||||
|
||||
@ -279,4 +279,36 @@ TEST(CAPI, XGBGlobalConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CAPI, GlobalVariables) {
|
||||
size_t n_threads = omp_get_max_threads();
|
||||
size_t constexpr kRows = 10;
|
||||
bst_feature_t constexpr kCols = 2;
|
||||
|
||||
DMatrixHandle handle;
|
||||
std::vector<float> data(kCols * kRows, 1.5);
|
||||
|
||||
|
||||
ASSERT_EQ(XGDMatrixCreateFromMat_omp(data.data(), kRows, kCols,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&handle, 0),
|
||||
0);
|
||||
std::vector<float> labels(kRows, 2.0f);
|
||||
ASSERT_EQ(XGDMatrixSetFloatInfo(handle, "label", labels.data(), labels.size()), 0);
|
||||
|
||||
DMatrixHandle m_handles[1];
|
||||
m_handles[0] = handle;
|
||||
|
||||
BoosterHandle booster;
|
||||
ASSERT_EQ(XGBoosterCreate(m_handles, 1, &booster), 0);
|
||||
ASSERT_EQ(XGBoosterSetParam(booster, "nthread", "16"), 0);
|
||||
|
||||
omp_set_num_threads(1);
|
||||
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, handle), 0);
|
||||
ASSERT_EQ(omp_get_max_threads(), 1);
|
||||
|
||||
omp_set_num_threads(n_threads);
|
||||
|
||||
XGDMatrixFree(handle);
|
||||
XGBoosterFree(booster);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user