Add support inference on SYCL devices (#9800)

---------

Co-authored-by: Dmitry Razdoburdin <>
Co-authored-by: Nikolay Petrov <nikolay.a.petrov@intel.com>
Co-authored-by: Alexandra <alexandra.epanchinzeva@intel.com>
This commit is contained in:
Dmitry Razdoburdin
2023-12-04 09:15:57 +01:00
committed by GitHub
parent 7196c9d95e
commit 381f1d3dc9
31 changed files with 1369 additions and 1294 deletions

View File

@@ -16,6 +16,10 @@ if(USE_CUDA)
target_sources(objxgboost PRIVATE ${CUDA_SOURCES})
endif()
if(PLUGIN_SYCL)
target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_SYCL=1)
endif()
target_include_directories(objxgboost
PRIVATE
${xgboost_SOURCE_DIR}/include

View File

@@ -169,10 +169,10 @@ inline void AssertNCCLSupport() {
#endif // !defined(XGBOOST_USE_NCCL)
}
inline void AssertOneAPISupport() {
#ifndef XGBOOST_USE_ONEAPI
LOG(FATAL) << "XGBoost version not compiled with OneAPI support.";
#endif // XGBOOST_USE_ONEAPI
inline void AssertSYCLSupport() {
#ifndef XGBOOST_USE_SYCL
LOG(FATAL) << "XGBoost version not compiled with SYCL support.";
#endif // XGBOOST_USE_SYCL
}
void SetDevice(std::int32_t device);

View File

@@ -113,13 +113,13 @@ void GBTree::Configure(Args const& cfg) {
}
#endif // defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_ONEAPI)
if (!oneapi_predictor_) {
oneapi_predictor_ =
std::unique_ptr<Predictor>(Predictor::Create("oneapi_predictor", this->ctx_));
#if defined(XGBOOST_USE_SYCL)
if (!sycl_predictor_) {
sycl_predictor_ =
std::unique_ptr<Predictor>(Predictor::Create("sycl_predictor", this->ctx_));
}
oneapi_predictor_->Configure(cfg);
#endif // defined(XGBOOST_USE_ONEAPI)
sycl_predictor_->Configure(cfg);
#endif // defined(XGBOOST_USE_SYCL)
// `updater` parameter was manually specified
specified_updater_ =
@@ -553,6 +553,11 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
},
[&, begin = tree_begin, end = tree_end] {
return this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
#if defined(XGBOOST_USE_SYCL)
},
[&, begin = tree_begin, end = tree_end] {
return this->sycl_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
#endif // defined(XGBOOST_USE_SYCL)
});
if (!known_type) {
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
@@ -568,10 +573,16 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
if (f_dmat && !f_dmat->SingleColBlock()) {
if (ctx_->IsCPU()) {
return cpu_predictor_;
} else {
} else if (ctx_->IsCUDA()) {
common::AssertGPUSupport();
CHECK(gpu_predictor_);
return gpu_predictor_;
} else {
#if defined(XGBOOST_USE_SYCL)
common::AssertSYCLSupport();
CHECK(sycl_predictor_);
return sycl_predictor_;
#endif // defined(XGBOOST_USE_SYCL)
}
}
@@ -606,10 +617,16 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
if (ctx_->IsCPU()) {
return cpu_predictor_;
} else {
} else if (ctx_->IsCUDA()) {
common::AssertGPUSupport();
CHECK(gpu_predictor_);
return gpu_predictor_;
} else {
#if defined(XGBOOST_USE_SYCL)
common::AssertSYCLSupport();
CHECK(sycl_predictor_);
return sycl_predictor_;
#endif // defined(XGBOOST_USE_SYCL)
}
return cpu_predictor_;
@@ -814,6 +831,11 @@ class Dart : public GBTree {
},
[&] {
return gpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
#if defined(XGBOOST_USE_SYCL)
},
[&] {
return sycl_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
#endif // defined(XGBOOST_USE_SYCL)
});
CHECK(success) << msg;
};
@@ -830,6 +852,12 @@ class Dart : public GBTree {
[&] {
this->gpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
#if defined(XGBOOST_USE_SYCL)
},
[&] {
this->sycl_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
#endif // defined(XGBOOST_USE_SYCL)
});
}
// Multiple the tree weight

View File

@@ -349,9 +349,9 @@ class GBTree : public GradientBooster {
// Predictors
std::unique_ptr<Predictor> cpu_predictor_;
std::unique_ptr<Predictor> gpu_predictor_{nullptr};
#if defined(XGBOOST_USE_ONEAPI)
std::unique_ptr<Predictor> oneapi_predictor_;
#endif // defined(XGBOOST_USE_ONEAPI)
#if defined(XGBOOST_USE_SYCL)
std::unique_ptr<Predictor> sycl_predictor_;
#endif // defined(XGBOOST_USE_SYCL)
common::Monitor monitor_;
};