Use context in SetInfo. (#7687)
* Use the name `Context`. * Pass a context object into `SetInfo`. * Add context to proxy matrix. * Add context to iterative DMatrix. This is to remove the use of the default number of threads during `SetInfo` as a follow-up on removing the global omp variable while preparing for CUDA stream semantic. Currently, XGBoost uses the legacy CUDA stream, we will gradually remove them in the future in favor of non-blocking streams.
This commit is contained in:
@@ -485,35 +485,30 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const bst_float* info,
|
||||
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const bst_float *info,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo(field, info, xgboost::DataType::kFloat32, len);
|
||||
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
|
||||
char const* field,
|
||||
char const* interface_c_str) {
|
||||
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, char const *field,
|
||||
char const *interface_c_str) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo(field, interface_c_str);
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo(field, interface_c_str);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const unsigned* info,
|
||||
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *info,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo(field, info, xgboost::DataType::kUInt32, len);
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -549,25 +544,22 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
|
||||
void const *data, xgboost::bst_ulong size,
|
||||
int type) {
|
||||
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
|
||||
xgboost::bst_ulong size, int type) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
CHECK(type >= 1 && type <= 4);
|
||||
info.SetInfo(field, data, static_cast<DataType>(type), size);
|
||||
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
const unsigned* group,
|
||||
xgboost::bst_ulong len) {
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo("group", group, xgboost::DataType::kUInt32, len);
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user