Add support inference on SYCL devices (#9800)
--------- Co-authored-by: Dmitry Razdoburdin <> Co-authored-by: Nikolay Petrov <nikolay.a.petrov@intel.com> Co-authored-by: Alexandra <alexandra.epanchinzeva@intel.com>
This commit is contained in:
committed by
GitHub
parent
7196c9d95e
commit
381f1d3dc9
@@ -16,6 +16,10 @@ if(USE_CUDA)
|
||||
target_sources(objxgboost PRIVATE ${CUDA_SOURCES})
|
||||
endif()
|
||||
|
||||
if(PLUGIN_SYCL)
|
||||
target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_SYCL=1)
|
||||
endif()
|
||||
|
||||
target_include_directories(objxgboost
|
||||
PRIVATE
|
||||
${xgboost_SOURCE_DIR}/include
|
||||
|
||||
@@ -169,10 +169,10 @@ inline void AssertNCCLSupport() {
|
||||
#endif // !defined(XGBOOST_USE_NCCL)
|
||||
}
|
||||
|
||||
inline void AssertOneAPISupport() {
|
||||
#ifndef XGBOOST_USE_ONEAPI
|
||||
LOG(FATAL) << "XGBoost version not compiled with OneAPI support.";
|
||||
#endif // XGBOOST_USE_ONEAPI
|
||||
inline void AssertSYCLSupport() {
|
||||
#ifndef XGBOOST_USE_SYCL
|
||||
LOG(FATAL) << "XGBoost version not compiled with SYCL support.";
|
||||
#endif // XGBOOST_USE_SYCL
|
||||
}
|
||||
|
||||
void SetDevice(std::int32_t device);
|
||||
|
||||
@@ -113,13 +113,13 @@ void GBTree::Configure(Args const& cfg) {
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
#if defined(XGBOOST_USE_ONEAPI)
|
||||
if (!oneapi_predictor_) {
|
||||
oneapi_predictor_ =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("oneapi_predictor", this->ctx_));
|
||||
#if defined(XGBOOST_USE_SYCL)
|
||||
if (!sycl_predictor_) {
|
||||
sycl_predictor_ =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("sycl_predictor", this->ctx_));
|
||||
}
|
||||
oneapi_predictor_->Configure(cfg);
|
||||
#endif // defined(XGBOOST_USE_ONEAPI)
|
||||
sycl_predictor_->Configure(cfg);
|
||||
#endif // defined(XGBOOST_USE_SYCL)
|
||||
|
||||
// `updater` parameter was manually specified
|
||||
specified_updater_ =
|
||||
@@ -553,6 +553,11 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
|
||||
},
|
||||
[&, begin = tree_begin, end = tree_end] {
|
||||
return this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
|
||||
#if defined(XGBOOST_USE_SYCL)
|
||||
},
|
||||
[&, begin = tree_begin, end = tree_end] {
|
||||
return this->sycl_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
|
||||
#endif // defined(XGBOOST_USE_SYCL)
|
||||
});
|
||||
if (!known_type) {
|
||||
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
|
||||
@@ -568,10 +573,16 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
|
||||
if (f_dmat && !f_dmat->SingleColBlock()) {
|
||||
if (ctx_->IsCPU()) {
|
||||
return cpu_predictor_;
|
||||
} else {
|
||||
} else if (ctx_->IsCUDA()) {
|
||||
common::AssertGPUSupport();
|
||||
CHECK(gpu_predictor_);
|
||||
return gpu_predictor_;
|
||||
} else {
|
||||
#if defined(XGBOOST_USE_SYCL)
|
||||
common::AssertSYCLSupport();
|
||||
CHECK(sycl_predictor_);
|
||||
return sycl_predictor_;
|
||||
#endif // defined(XGBOOST_USE_SYCL)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -606,10 +617,16 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
|
||||
|
||||
if (ctx_->IsCPU()) {
|
||||
return cpu_predictor_;
|
||||
} else {
|
||||
} else if (ctx_->IsCUDA()) {
|
||||
common::AssertGPUSupport();
|
||||
CHECK(gpu_predictor_);
|
||||
return gpu_predictor_;
|
||||
} else {
|
||||
#if defined(XGBOOST_USE_SYCL)
|
||||
common::AssertSYCLSupport();
|
||||
CHECK(sycl_predictor_);
|
||||
return sycl_predictor_;
|
||||
#endif // defined(XGBOOST_USE_SYCL)
|
||||
}
|
||||
|
||||
return cpu_predictor_;
|
||||
@@ -814,6 +831,11 @@ class Dart : public GBTree {
|
||||
},
|
||||
[&] {
|
||||
return gpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
|
||||
#if defined(XGBOOST_USE_SYCL)
|
||||
},
|
||||
[&] {
|
||||
return sycl_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
|
||||
#endif // defined(XGBOOST_USE_SYCL)
|
||||
});
|
||||
CHECK(success) << msg;
|
||||
};
|
||||
@@ -830,6 +852,12 @@ class Dart : public GBTree {
|
||||
[&] {
|
||||
this->gpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
|
||||
model_);
|
||||
#if defined(XGBOOST_USE_SYCL)
|
||||
},
|
||||
[&] {
|
||||
this->sycl_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
|
||||
model_);
|
||||
#endif // defined(XGBOOST_USE_SYCL)
|
||||
});
|
||||
}
|
||||
// Multiple the tree weight
|
||||
|
||||
@@ -349,9 +349,9 @@ class GBTree : public GradientBooster {
|
||||
// Predictors
|
||||
std::unique_ptr<Predictor> cpu_predictor_;
|
||||
std::unique_ptr<Predictor> gpu_predictor_{nullptr};
|
||||
#if defined(XGBOOST_USE_ONEAPI)
|
||||
std::unique_ptr<Predictor> oneapi_predictor_;
|
||||
#endif // defined(XGBOOST_USE_ONEAPI)
|
||||
#if defined(XGBOOST_USE_SYCL)
|
||||
std::unique_ptr<Predictor> sycl_predictor_;
|
||||
#endif // defined(XGBOOST_USE_SYCL)
|
||||
common::Monitor monitor_;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user