Switch back to the GPUIDX macro (#9438)
This commit is contained in:
@@ -34,6 +34,12 @@
|
||||
#define DeclareUnifiedTest(name) name
|
||||
#endif
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#define GPUIDX (common::AllVisibleGPUs() == 1 ? 0 : collective::GetRank())
|
||||
#else
|
||||
#define GPUIDX (-1)
|
||||
#endif
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#define DeclareUnifiedDistributedTest(name) MGPU ## name
|
||||
#else
|
||||
@@ -540,15 +546,6 @@ void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&
|
||||
#endif
|
||||
}
|
||||
|
||||
inline int GetGPUId() {
|
||||
#if defined(__CUDACC__)
|
||||
auto const n_gpus = common::AllVisibleGPUs();
|
||||
return n_gpus == 1 ? 0 : collective::GetRank();
|
||||
#else
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
|
||||
class BaseMGPUTest : public ::testing::Test {
|
||||
protected:
|
||||
int world_size_;
|
||||
|
||||
Reference in New Issue
Block a user