diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh index ae3b3f581..ee6306c15 100644 --- a/src/collective/device_communicator_adapter.cuh +++ b/src/collective/device_communicator_adapter.cuh @@ -45,7 +45,12 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { return; } +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_ordinal_)); +#else dh::safe_cuda(cudaSetDevice(device_ordinal_)); +#endif + int const world_size = communicator_->GetWorldSize(); int const rank = communicator_->GetRank(); @@ -62,14 +67,25 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { for (int32_t i = 0; i < world_size; ++i) { size_t as_bytes = segments->at(i); if (i == rank) { +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank), + hipMemcpyDefault)); +#else dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank), cudaMemcpyDefault)); +#endif } communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i); offset += as_bytes; } + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes, + hipMemcpyDefault)); +#else dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes, cudaMemcpyDefault)); +#endif } void Synchronize() override { @@ -83,12 +99,24 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator { return; } +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_ordinal_)); +#else dh::safe_cuda(cudaSetDevice(device_ordinal_)); +#endif + auto size = count * sizeof(T); host_buffer_.reserve(size); + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpy(host_buffer_.data(), send_receive_buffer, size, hipMemcpyDefault)); + communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum); + dh::safe_cuda(hipMemcpy(send_receive_buffer, host_buffer_.data(), size, hipMemcpyDefault)); +#else dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault)); communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum); dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); +#endif } int const device_ordinal_;