Update CUDA docker image and NCCL. (#8139)

This commit is contained in:
Jiaming Yuan
2022-08-07 16:32:41 +08:00
committed by GitHub
parent ff471b3fab
commit bcc8679a05
7 changed files with 43 additions and 15 deletions

View File

@@ -33,6 +33,9 @@ std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
void NcclAllReducer::DoInit(int _device_ordinal) {
int32_t const rank = rabit::GetRank();
int32_t const world = rabit::GetWorldSize();
if (world == 1) {
return;
}
std::vector<uint64_t> uuids(world * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};

View File

@@ -754,6 +754,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
void Init(int _device_ordinal) {
  // Remember the requested ordinal and make it the current CUDA device first;
  // both steps run regardless of the world size.
  device_ordinal_ = _device_ordinal;
  dh::safe_cuda(cudaSetDevice(device_ordinal_));

  // When the world size is exactly 1 the derived DoInit() is never invoked and
  // initialised_ is left at whatever value it already had.
  bool const multi_worker = rabit::GetWorldSize() != 1;
  if (multi_worker) {
    this->Underlying().DoInit(_device_ordinal);
    initialised_ = true;
  }
}
@@ -769,6 +772,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
*/
void AllGather(void const *data, size_t length_bytes, std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllGather(data, length_bytes, segments, recvbuf);
@@ -784,6 +790,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
*/
void AllGather(uint32_t const *data, size_t length,
dh::caching_device_vector<uint32_t> *recvbuf) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllGather(data, length, recvbuf);
@@ -798,6 +807,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
* \param count Number of elements.
*/
void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
@@ -814,6 +826,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
* \param count Number of elements.
*/
void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
@@ -831,6 +846,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
* \param count Number of elements.
*/
void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
@@ -847,6 +865,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
* \param count Number of elements.
*/
void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
@@ -863,6 +884,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
* \param count Number of elements.
*/
void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
this->Underlying().DoAllReduceSum(sendbuff, recvbuff, count);
@@ -886,6 +910,9 @@ class AllReducerBase : public xgboost::common::Crtp<AllReducer> {
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
if (rabit::GetWorldSize() == 1) {
return;
}
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT