Add support inference on SYCL devices (#9800)
--------- Co-authored-by: Dmitry Razdoburdin <> Co-authored-by: Nikolay Petrov <nikolay.a.petrov@intel.com> Co-authored-by: Alexandra <alexandra.epanchinzeva@intel.com>
This commit is contained in:
committed by
GitHub
parent
7196c9d95e
commit
381f1d3dc9
@@ -250,9 +250,15 @@ struct Context : public XGBoostParameter<Context> {
|
||||
default:
|
||||
// Do not use the device name as this is likely an internal error, the name
|
||||
// wouldn't be valid.
|
||||
LOG(FATAL) << "Unknown device type:"
|
||||
<< static_cast<std::underlying_type_t<DeviceOrd::Type>>(this->Device().device);
|
||||
break;
|
||||
if (this->Device().IsSycl()) {
|
||||
LOG(WARNING) << "The requested feature doesn't have SYCL specific implementation yet. "
|
||||
<< "CPU implementation is used";
|
||||
return cpu_fn();
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown device type:"
|
||||
<< static_cast<std::underlying_type_t<DeviceOrd::Type>>(this->Device().device);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return std::invoke_result_t<CPUFn>();
|
||||
}
|
||||
@@ -262,7 +268,6 @@ struct Context : public XGBoostParameter<Context> {
|
||||
*/
|
||||
template <typename CPUFn, typename CUDAFn, typename SYCLFn>
|
||||
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn, SYCLFn&& sycl_fn) const {
|
||||
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
|
||||
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<SYCLFn>>);
|
||||
if (this->Device().IsSycl()) {
|
||||
return sycl_fn();
|
||||
|
||||
@@ -92,8 +92,8 @@ class Predictor {
|
||||
* \param out_predt Prediction vector to be initialized.
|
||||
* \param model Tree model used for prediction.
|
||||
*/
|
||||
void InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_predt,
|
||||
const gbm::GBTreeModel& model) const;
|
||||
virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_predt,
|
||||
const gbm::GBTreeModel& model) const;
|
||||
|
||||
/**
|
||||
* \brief Generate batch predictions for a given feature matrix. May use
|
||||
|
||||
Reference in New Issue
Block a user