Update CUDA docker image and NCCL. (#8139)
This commit is contained in:
@@ -33,6 +33,9 @@ std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
|
||||
void NcclAllReducer::DoInit(int _device_ordinal) {
|
||||
int32_t const rank = rabit::GetRank();
|
||||
int32_t const world = rabit::GetWorldSize();
|
||||
if (world == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
|
||||
@@ -754,6 +754,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
void Init(int _device_ordinal) {
|
||||
device_ordinal_ = _device_ordinal;
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
this->Underlying().DoInit(_device_ordinal);
|
||||
initialised_ = true;
|
||||
}
|
||||
@@ -769,6 +772,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
*/
|
||||
void AllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *recvbuf) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllGather(data, length_bytes, segments, recvbuf);
|
||||
@@ -784,6 +790,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
*/
|
||||
void AllGather(uint32_t const *data, size_t length,
|
||||
dh::caching_device_vector<uint32_t> *recvbuf) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllGather(data, length, recvbuf);
|
||||
@@ -798,6 +807,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
@@ -814,6 +826,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
@@ -831,6 +846,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
* \param count Number of.
|
||||
*/
|
||||
void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
@@ -847,6 +865,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
@@ -863,6 +884,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
|
||||
@@ -886,6 +910,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
|
||||
!std::is_same<size_t, unsigned long long>::value> // NOLINT
|
||||
* = nullptr>
|
||||
void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
|
||||
if (rabit::GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT
|
||||
|
||||
Reference in New Issue
Block a user