add cuda to hip wrapper
This commit is contained in:
@@ -26,22 +26,12 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#endif
|
||||
auto size = count * GetTypeSize(data_type);
|
||||
host_buffer_.resize(size);
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
Allreduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(host_buffer_.data(), send_receive_buffer, size, hipMemcpyDefault));
|
||||
AllReduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(hipMemcpy(send_receive_buffer, host_buffer_.data(), size, hipMemcpyDefault));
|
||||
#endif
|
||||
}
|
||||
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
|
||||
@@ -49,7 +39,6 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size * world_size_);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
|
||||
@@ -57,15 +46,6 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
Allgather(host_buffer_.data(), host_buffer_.size());
|
||||
dh::safe_cuda(
|
||||
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size * world_size_);
|
||||
dh::safe_cuda(hipMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
|
||||
hipMemcpyDefault));
|
||||
Allgather(host_buffer_.data(), host_buffer_.size());
|
||||
dh::safe_cuda(
|
||||
hipMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), hipMemcpyDefault));
|
||||
#endif
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
@@ -74,11 +54,7 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
#endif
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size_, 0);
|
||||
@@ -92,25 +68,15 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
if (i == rank_) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
|
||||
cudaMemcpyDefault));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
|
||||
hipMemcpyDefault));
|
||||
#endif
|
||||
}
|
||||
Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||
offset += as_bytes;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||
hipMemcpyDefault));
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||
cudaMemcpyDefault));
|
||||
#endif
|
||||
}
|
||||
|
||||
void Synchronize() override {
|
||||
|
||||
Reference in New Issue
Block a user