Switch back to the GPUIDX macro (#9438)
This commit is contained in:
@@ -351,7 +351,7 @@ void TestAllReduceBasic() {
|
||||
auto const world = collective::GetWorldSize();
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
|
||||
auto const device = GetGPUId();
|
||||
auto const device = GPUIDX;
|
||||
|
||||
// Set up single node version;
|
||||
HostDeviceVector<FeatureType> ft({}, device);
|
||||
@@ -440,7 +440,7 @@ void TestColumnSplitBasic() {
|
||||
}()};
|
||||
|
||||
// Generate cuts for distributed environment.
|
||||
auto ctx = MakeCUDACtx(GetGPUId());
|
||||
auto ctx = MakeCUDACtx(GPUIDX);
|
||||
HistogramCuts distributed_cuts = common::DeviceSketch(&ctx, m.get(), kBins);
|
||||
|
||||
// Generate cuts for single node environment
|
||||
@@ -483,7 +483,7 @@ void TestSameOnAllWorkers() {
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
|
||||
MetaInfo const &info) {
|
||||
auto const rank = collective::GetRank();
|
||||
auto const device = GetGPUId();
|
||||
auto const device = GPUIDX;
|
||||
HostDeviceVector<FeatureType> ft({}, device);
|
||||
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, device);
|
||||
HostDeviceVector<float> storage({}, device);
|
||||
|
||||
Reference in New Issue
Block a user