merge latest changes

This commit is contained in:
amdsc21
2023-06-15 21:39:14 +02:00
18 changed files with 284 additions and 151 deletions

View File

@@ -2,11 +2,11 @@
#include "test_quantile.h"
#include "../helpers.h"
#if defined(XGBOOST_USE_CUDA)
#include "../../../src/collective/device_communicator.cuh"
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/common/hist_util.cuh"
#include "../../../src/common/quantile.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../../../src/collective/device_communicator.hip.h"
#include "../../../src/collective/communicator-inl.hip.h"
#include "../../../src/common/hist_util.hip.h"
#include "../../../src/common/quantile.hip.h"
#endif
@@ -474,10 +474,9 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) {
thrust::copy(thrust::device, local_data.data(),
local_data.data() + local_data.size(),
all_workers.begin() + local_data.size() * rank);
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(device);
communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
communicator->Synchronize();
collective::AllReduce<collective::Operation::kSum>(device, all_workers.data().get(),
all_workers.size());
collective::Synchronize(device);
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
std::vector<float> h_base_line(base_line.size());