Use context in SetInfo. (#7687)

* Use the name `Context`.
* Pass a context object into `SetInfo`.
* Add context to proxy matrix.
* Add context to iterative DMatrix.

This is to remove the use of the default number of threads during `SetInfo` as a follow-up on
removing the global omp variable while preparing for CUDA stream semantic.  Currently, XGBoost
uses the legacy CUDA stream, we will gradually remove them in the future in favor of non-blocking streams.
This commit is contained in:
Jiaming Yuan
2022-03-24 22:16:26 +08:00
committed by GitHub
parent f5b20286e2
commit 64575591d8
19 changed files with 142 additions and 142 deletions

View File

@@ -485,35 +485,30 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname,
API_END();
}
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char* field,
const bst_float* info,
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const bst_float *info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, xgboost::DataType::kFloat32, len);
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
API_END();
}
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
char const* field,
char const* interface_c_str) {
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, char const *field,
char const *interface_c_str) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, interface_c_str);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, interface_c_str);
API_END();
}
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char* field,
const unsigned* info,
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, xgboost::DataType::kUInt32, len);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
API_END();
}
@@ -549,25 +544,22 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
API_END();
}
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void const *data, xgboost::bst_ulong size,
int type) {
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
xgboost::bst_ulong size, int type) {
API_BEGIN();
CHECK_HANDLE();
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
CHECK(type >= 1 && type <= 4);
info.SetInfo(field, data, static_cast<DataType>(type), size);
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
API_END();
}
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group,
xgboost::bst_ulong len) {
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo("group", group, xgboost::DataType::kUInt32, len);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
API_END();
}