Clean up MGPU C++ tests (#9430)

This commit is contained in:
Rong Ou
2023-08-01 23:31:18 -07:00
committed by GitHub
parent a9da2e244a
commit c2b85ab68a
28 changed files with 200 additions and 194 deletions

View File

@@ -41,7 +41,8 @@ void Communicator::Init(Json const& config) {
#endif
break;
}
case CommunicatorType::kInMemory: {
case CommunicatorType::kInMemory:
case CommunicatorType::kInMemoryNccl: {
communicator_.reset(InMemoryCommunicator::Create(config));
break;
}

View File

@@ -34,9 +34,10 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
break;
case CommunicatorType::kFederated:
case CommunicatorType::kInMemory:
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
break;
case CommunicatorType::kInMemory:
case CommunicatorType::kInMemoryNccl:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
break;
default:

View File

@@ -69,7 +69,7 @@ enum class Operation {
class DeviceCommunicator;
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory };
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory, kInMemoryNccl };
/** \brief Case-insensitive string comparison. */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
@@ -220,6 +220,8 @@ class Communicator {
result = CommunicatorType::kFederated;
} else if (!CompareStringsCaseInsensitive("in-memory", str)) {
result = CommunicatorType::kInMemory;
} else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
result = CommunicatorType::kInMemoryNccl;
} else {
LOG(FATAL) << "Unknown communicator type " << str;
}