Refactor device communicator to make allreduce more flexible (#9295)
parent c2f0486d37
commit e70810be8a
src/collective/communicator-inl.cuh (new file, 81 lines added)
@@ -0,0 +1,81 @@
/**
 * Copyright 2023 by XGBoost contributors
 */
#pragma once
#include <string>
#include <vector>

#include "communicator.h"
#include "device_communicator.cuh"

namespace xgboost {
namespace collective {

/**
 * @brief Reduce values from all processes and distribute the result back to all processes.
 * @param device ID of the device.
 * @param send_receive_buffer Buffer storing the data.
 * @param count Number of elements in the buffer.
 */
template <Operation op>
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}

template <Operation op>
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}

template <Operation op>
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}

template <Operation op>
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}

template <Operation op>
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}

/**
 * @brief Gather variable-length values from all processes.
 * @param device ID of the device.
 * @param send_buffer Buffer storing the input data.
 * @param length_bytes Length in bytes of the input data.
 * @param segments Size of each segment.
 * @param receive_buffer Buffer storing the output data.
 */
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
                       std::vector<size_t> *segments,
                       dh::caching_device_vector<char> *receive_buffer) {
  Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
}

/**
 * @brief Synchronize device operations.
 * @param device ID of the device.
 */
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }

} // namespace collective
} // namespace xgboost
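For context, a minimal usage sketch of the new header follows. It mirrors the call sites updated later in this commit; the wrapper function SumAcrossWorkers, its arguments, and the include path are illustrative assumptions, and the collective communicator is assumed to be initialised already.

// Sketch: sum a device-resident buffer across all workers with the new API.
#include "collective/communicator-inl.cuh"  // include path is illustrative

void SumAcrossWorkers(int device, double *d_values, std::size_t n) {
  // One templated entry point replaces the former per-type AllReduceSum overloads;
  // the double* overload forwards DataType::kDouble together with the chosen Operation.
  xgboost::collective::AllReduce<xgboost::collective::Operation::kSum>(device, d_values, n);
  // Block until the reduction has completed on this device, as the updated tests do.
  xgboost::collective::Synchronize(device);
}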
@@ -17,32 +17,15 @@ class DeviceCommunicator {
  virtual ~DeviceCommunicator() = default;

  /**
   * @brief Sum values from all processes and distribute the result back to all processes.
   * @brief Combines values from all processes and distributes the result back to all processes.
   *
   * @param send_receive_buffer Buffer storing the data.
   * @param count Number of elements in the buffer.
   * @param data_type Data type stored in the buffer.
   * @param op The operation to perform.
   */
  virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;

  /**
   * @brief Sum values from all processes and distribute the result back to all processes.
   * @param send_receive_buffer Buffer storing the data.
   * @param count Number of elements in the buffer.
   */
  virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;

  /**
   * @brief Sum values from all processes and distribute the result back to all processes.
   * @param send_receive_buffer Buffer storing the data.
   * @param count Number of elements in the buffer.
   */
  virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;

  /**
   * @brief Sum values from all processes and distribute the result back to all processes.
   * @param send_receive_buffer Buffer storing the data.
   * @param count Number of elements in the buffer.
   */
  virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
  virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                         Operation op) = 0;

  /**
   * @brief Gather variable-length values from all processes.
@@ -23,20 +23,18 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {

  ~DeviceCommunicatorAdapter() override = default;

  void AllReduceSum(float *send_receive_buffer, size_t count) override {
    DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
  }
  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                 Operation op) override {
    if (communicator_->GetWorldSize() == 1) {
      return;
    }

  void AllReduceSum(double *send_receive_buffer, size_t count) override {
    DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
  }

  void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
    DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
  }

  void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
    DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    auto size = count * GetTypeSize(data_type);
    host_buffer_.reserve(size);
    dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
    communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
    dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
  }

  void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
@@ -77,20 +75,6 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
  }

 private:
  template <collective::DataType data_type, typename T>
  void DoAllReduceSum(T *send_receive_buffer, size_t count) {
    if (communicator_->GetWorldSize() == 1) {
      return;
    }

    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    auto size = count * sizeof(T);
    host_buffer_.reserve(size);
    dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
    communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
    dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
  }

  int const device_ordinal_;
  Communicator *communicator_;
  /// Host buffer used to call communicator functions.
@@ -72,20 +72,18 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
    }
  }

  void AllReduceSum(float *send_receive_buffer, size_t count) override {
    DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
  }
  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                 Operation op) override {
    if (communicator_->GetWorldSize() == 1) {
      return;
    }

  void AllReduceSum(double *send_receive_buffer, size_t count) override {
    DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
  }

  void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
    DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
  }

  void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
    DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
                                GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
                                cuda_stream_));
    allreduce_bytes_ += count * GetTypeSize(data_type);
    allreduce_calls_ += 1;
  }

  void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
@@ -162,17 +160,59 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
    return id;
  }

  template <ncclDataType_t data_type, typename T>
  void DoAllReduceSum(T *send_receive_buffer, size_t count) {
    if (communicator_->GetWorldSize() == 1) {
      return;
  static ncclDataType_t GetNcclDataType(DataType const &data_type) {
    ncclDataType_t result;
    switch (data_type) {
      case DataType::kInt8:
        result = ncclInt8;
        break;
      case DataType::kUInt8:
        result = ncclUint8;
        break;
      case DataType::kInt32:
        result = ncclInt32;
        break;
      case DataType::kUInt32:
        result = ncclUint32;
        break;
      case DataType::kInt64:
        result = ncclInt64;
        break;
      case DataType::kUInt64:
        result = ncclUint64;
        break;
      case DataType::kFloat:
        result = ncclFloat;
        break;
      case DataType::kDouble:
        result = ncclDouble;
        break;
      default:
        LOG(FATAL) << "Unknown data type.";
    }
    return result;
  }

    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
                                nccl_comm_, cuda_stream_));
    allreduce_bytes_ += count * sizeof(T);
    allreduce_calls_ += 1;
  static ncclRedOp_t GetNcclRedOp(Operation const &op) {
    ncclRedOp_t result;
    switch (op) {
      case Operation::kMax:
        result = ncclMax;
        break;
      case Operation::kMin:
        result = ncclMin;
        break;
      case Operation::kSum:
        result = ncclSum;
        break;
      case Operation::kBitwiseAND:
      case Operation::kBitwiseOR:
      case Operation::kBitwiseXOR:
        LOG(FATAL) << "Not implemented yet.";
      default:
        LOG(FATAL) << "Unknown reduce operation.";
    }
    return result;
  }

  int const device_ordinal_;
@@ -12,8 +12,7 @@
#include <memory>
#include <utility>

#include "../collective/communicator.h"
#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "categorical.h"
#include "common.h"
#include "device_helpers.cuh"
@@ -510,7 +509,6 @@ void SketchContainer::AllReduce() {
  }

  timer_.Start(__func__);
  auto* communicator = collective::Communicator::GetDevice(device_);
  // Reduce the overhead on syncing.
  size_t global_sum_rows = num_rows_;
  collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
@@ -531,14 +529,15 @@ void SketchContainer::AllReduce() {
  auto offset = rank * d_columns_ptr.size();
  thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
               gathered_ptrs.begin() + offset);
  communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());
  collective::AllReduce<collective::Operation::kSum>(device_, gathered_ptrs.data().get(),
                                                     gathered_ptrs.size());

  // Get the data from all workers.
  std::vector<size_t> recv_lengths;
  dh::caching_device_vector<char> recvbuf;
  communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
                           &recv_lengths, &recvbuf);
  communicator->Synchronize();
  collective::AllGatherV(device_, this->Current().data().get(),
                         dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf);
  collective::Synchronize(device_);

  // Segment the received data.
  auto s_recvbuf = dh::ToSpan(recvbuf);
@@ -11,7 +11,7 @@
#include <tuple>
#include <utility>

#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../common/algorithm.cuh"        // SegmentedArgSort
#include "../common/optional_weight.h"    // OptionalWeights
#include "../common/threading_utils.cuh"  // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
@@ -205,8 +205,7 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
  if (collective::IsDistributed()) {
    int32_t device = dh::CurrentDevice();
    CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
    auto* communicator = collective::Communicator::GetDevice(device);
    communicator->AllReduceSum(results.data(), results.size());
    collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
  }
  auto reduce_in = dh::MakeTransformIterator<Pair>(
      thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
@@ -11,7 +11,7 @@

#include <cstddef>  // std::size_t

#include "../collective/device_communicator.cuh"  // DeviceCommunicator
#include "../collective/communicator-inl.cuh"
#include "../common/device_helpers.cuh"  // dh::MakeTransformIterator
#include "fit_stump.h"
#include "xgboost/base.h"  // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
@@ -49,8 +49,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
  thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
                        thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));

  collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(ctx->gpu_id);
  communicator->AllReduceSum(reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
  collective::AllReduce<collective::Operation::kSum>(
      ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);

  thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
                     [=] XGBOOST_DEVICE(std::size_t i) mutable {
@@ -12,7 +12,7 @@
#include <utility>
#include <vector>

#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/cuda_context.cuh"  // CUDAContext
@@ -546,12 +546,13 @@ struct GPUHistMakerDevice {
  }

  // num histograms is the number of contiguous histograms in memory to reduce over
  void AllReduceHist(int nidx, collective::DeviceCommunicator* communicator, int num_histograms) {
  void AllReduceHist(int nidx, int num_histograms) {
    monitor.Start("AllReduce");
    auto d_node_hist = hist.GetNodeHistogram(nidx).data();
    using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
    communicator->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
                               page->Cuts().TotalBins() * 2 * num_histograms);
    collective::AllReduce<collective::Operation::kSum>(
        ctx_->gpu_id, reinterpret_cast<ReduceT*>(d_node_hist),
        page->Cuts().TotalBins() * 2 * num_histograms);

    monitor.Stop("AllReduce");
  }
@@ -559,8 +560,7 @@ struct GPUHistMakerDevice {
  /**
   * \brief Build GPU local histograms for the left and right child of some parent node
   */
  void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates,
                          collective::DeviceCommunicator* communicator, const RegTree& tree) {
  void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, const RegTree& tree) {
    if (candidates.empty()) return;
    // Some nodes we will manually compute histograms
    // others we will do by subtraction
@@ -591,7 +591,7 @@ struct GPUHistMakerDevice {
    // Reduce all in one go
    // This gives much better latency in a distributed setting
    // when processing a large batch
    this->AllReduceHist(hist_nidx.at(0), communicator, hist_nidx.size());
    this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());

    for (size_t i = 0; i < subtraction_nidx.size(); i++) {
      auto build_hist_nidx = hist_nidx.at(i);
@@ -601,7 +601,7 @@ struct GPUHistMakerDevice {
      if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
        // Calculate other histogram manually
        this->BuildHist(subtraction_trick_nidx);
        this->AllReduceHist(subtraction_trick_nidx, communicator, 1);
        this->AllReduceHist(subtraction_trick_nidx, 1);
      }
    }
  }
@@ -659,7 +659,7 @@ struct GPUHistMakerDevice {
                         parent.RightChild());
  }

  GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) {
  GPUExpandEntry InitRoot(RegTree* p_tree) {
    constexpr bst_node_t kRootNIdx = 0;
    dh::XGBCachingDeviceAllocator<char> alloc;
    auto quantiser = *this->quantiser;
@@ -676,7 +676,7 @@ struct GPUHistMakerDevice {

    hist.AllocateHistograms({kRootNIdx});
    this->BuildHist(kRootNIdx);
    this->AllReduceHist(kRootNIdx, communicator, 1);
    this->AllReduceHist(kRootNIdx, 1);

    // Remember root stats
    auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
@@ -692,7 +692,6 @@

  void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
                  ObjInfo const* task, RegTree* p_tree,
                  collective::DeviceCommunicator* communicator,
                  HostDeviceVector<bst_node_t>* p_out_position) {
    auto& tree = *p_tree;
    // Process maximum 32 nodes at a time
@@ -703,7 +702,7 @@ struct GPUHistMakerDevice {
    monitor.Stop("Reset");

    monitor.Start("InitRoot");
    driver.Push({ this->InitRoot(p_tree, communicator) });
    driver.Push({this->InitRoot(p_tree)});
    monitor.Stop("InitRoot");

    // The set of leaves that can be expanded asynchronously
@@ -730,7 +729,7 @@ struct GPUHistMakerDevice {
      monitor.Stop("UpdatePosition");

      monitor.Start("BuildHist");
      this->BuildHistLeftRight(filtered_expand_set, communicator, tree);
      this->BuildHistLeftRight(filtered_expand_set, tree);
      monitor.Stop("BuildHist");

      monitor.Start("EvaluateSplits");
@@ -851,8 +850,7 @@ class GPUHistMaker : public TreeUpdater {
    monitor_.Stop("InitData");

    gpair->SetDevice(ctx_->gpu_id);
    auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id);
    maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position);
    maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
  }

  bool UpdatePredictionCache(const DMatrix* data,
@@ -8,6 +8,7 @@
#include <string>  // for string

#include "../../../src/collective/nccl_device_communicator.cuh"
#include "../../../src/collective/communicator-inl.cuh"

namespace xgboost {
namespace collective {
@@ -1,7 +1,7 @@
#include <gtest/gtest.h>
#include "test_quantile.h"
#include "../helpers.h"
#include "../../../src/collective/device_communicator.cuh"
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/common/hist_util.cuh"
#include "../../../src/common/quantile.cuh"
@@ -464,10 +464,9 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) {
    thrust::copy(thrust::device, local_data.data(),
                 local_data.data() + local_data.size(),
                 all_workers.begin() + local_data.size() * rank);
    collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(device);

    communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
    communicator->Synchronize();
    collective::AllReduce<collective::Operation::kSum>(device, all_workers.data().get(),
                                                       all_workers.size());
    collective::Synchronize(device);

    auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
    std::vector<float> h_base_line(base_line.size());
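The same templated entry point accepts the other reduce operations as well. A hypothetical variation on the call above (not part of this commit; device, rank, and the vector name are illustrative) would take an element-wise maximum across workers:

    // Hypothetical sketch: each worker contributes its rank, so after the reduction
    // every element should equal the highest rank in the cluster.
    dh::caching_device_vector<float> maxima(8, static_cast<float>(rank));
    collective::AllReduce<collective::Operation::kMax>(device, maxima.data().get(),
                                                       maxima.size());
    collective::Synchronize(device);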
@@ -36,7 +36,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
  int count = 3;
  thrust::device_vector<double> buffer(count, 0);
  thrust::sequence(buffer.begin(), buffer.end());
  adapter.AllReduceSum(buffer.data().get(), count);
  adapter.AllReduce(buffer.data().get(), count, DataType::kDouble, Operation::kSum);
  thrust::host_vector<double> host_buffer = buffer;
  EXPECT_EQ(host_buffer.size(), count);
  for (auto i = 0; i < count; i++) {