[SYCL] Add dask support for distributed (#10812)

This commit is contained in:
Dmitry Razdoburdin
2024-09-21 20:01:57 +02:00
committed by GitHub
parent 2a37a8880c
commit d7599e095b
10 changed files with 219 additions and 6 deletions

View File

@@ -31,6 +31,33 @@ template void InitHist(::sycl::queue qu,
GHistRow<double, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event);
/*!
* \brief Copy histogram from src to dst
*/
template<typename GradientSumT>
void CopyHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* dst,
const GHistRow<GradientSumT, MemoryType::on_device>& src,
size_t size) {
GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst->Data());
const GradientSumT* psrc = reinterpret_cast<const GradientSumT*>(src.DataConst());
qu.submit([&](::sycl::handler& cgh) {
cgh.parallel_for<>(::sycl::range<1>(2 * size), [=](::sycl::item<1> pid) {
const size_t i = pid.get_id(0);
pdst[i] = psrc[i];
});
}).wait();
}
template void CopyHist(::sycl::queue qu,
GHistRow<float, MemoryType::on_device>* dst,
const GHistRow<float, MemoryType::on_device>& src,
size_t size);
template void CopyHist(::sycl::queue qu,
GHistRow<double, MemoryType::on_device>* dst,
const GHistRow<double, MemoryType::on_device>& src,
size_t size);
/*!
* \brief Compute Subtraction: dst = src1 - src2
*/

View File

@@ -36,6 +36,15 @@ void InitHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event);
/*!
* \brief Copy histogram from src to dst
*/
template<typename GradientSumT>
void CopyHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* dst,
const GHistRow<GradientSumT, MemoryType::on_device>& src,
size_t size);
/*!
* \brief Compute subtraction: dst = src1 - src2
*/