enable rocm, fix device_communicator_adapter.cuh
This commit is contained in:
parent
f2009533e1
commit
762fd9028d
@ -45,7 +45,12 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
|
#endif
|
||||||
|
|
||||||
int const world_size = communicator_->GetWorldSize();
|
int const world_size = communicator_->GetWorldSize();
|
||||||
int const rank = communicator_->GetRank();
|
int const rank = communicator_->GetRank();
|
||||||
|
|
||||||
@ -62,14 +67,25 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
|||||||
for (int32_t i = 0; i < world_size; ++i) {
|
for (int32_t i = 0; i < world_size; ++i) {
|
||||||
size_t as_bytes = segments->at(i);
|
size_t as_bytes = segments->at(i);
|
||||||
if (i == rank) {
|
if (i == rank) {
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
|
||||||
|
hipMemcpyDefault));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
|
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
|
||||||
cudaMemcpyDefault));
|
cudaMemcpyDefault));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||||
offset += as_bytes;
|
offset += as_bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||||
|
hipMemcpyDefault));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||||
cudaMemcpyDefault));
|
cudaMemcpyDefault));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void Synchronize() override {
|
void Synchronize() override {
|
||||||
@ -83,12 +99,24 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
|
#endif
|
||||||
|
|
||||||
auto size = count * sizeof(T);
|
auto size = count * sizeof(T);
|
||||||
host_buffer_.reserve(size);
|
host_buffer_.reserve(size);
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipMemcpy(host_buffer_.data(), send_receive_buffer, size, hipMemcpyDefault));
|
||||||
|
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
|
||||||
|
dh::safe_cuda(hipMemcpy(send_receive_buffer, host_buffer_.data(), size, hipMemcpyDefault));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
|
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
|
||||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
int const device_ordinal_;
|
int const device_ordinal_;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user