Refactor device communicator to make allreduce more flexible (#9295)

Rong Ou 2023-06-13 12:53:03 -07:00 committed by GitHub
parent c2f0486d37
commit e70810be8a
11 changed files with 190 additions and 106 deletions

View File

@@ -0,0 +1,81 @@
/**
 * Copyright 2023 by XGBoost contributors
 */
#pragma once

#include <string>
#include <vector>

#include "communicator.h"
#include "device_communicator.cuh"

namespace xgboost {
namespace collective {

/**
 * @brief Reduce values from all processes and distribute the result back to all processes.
 * @param device ID of the device.
 * @param send_receive_buffer Buffer storing the data.
 * @param count Number of elements in the buffer.
 */
template <Operation op>
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}

template <Operation op>
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}

template <Operation op>
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}

template <Operation op>
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}

template <Operation op>
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
  Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}

/**
 * @brief Gather variable-length values from all processes.
 * @param device ID of the device.
 * @param send_buffer Buffer storing the input data.
 * @param length_bytes Length in bytes of the input data.
 * @param segments Size of each segment.
 * @param receive_buffer Buffer storing the output data.
 */
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
                       std::vector<size_t> *segments,
                       dh::caching_device_vector<char> *receive_buffer) {
  Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
}

/**
 * @brief Synchronize device operations.
 * @param device ID of the device.
 */
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }

}  // namespace collective
}  // namespace xgboost
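
For orientation, the call sites later in this commit use these helpers roughly as follows. This is a minimal usage sketch; the buffer names (`sums`, `payload`) and the `device` ordinal are illustrative, not part of the API.

// Sketch only: assumes a valid device ordinal `device` and an initialized communicator.
dh::caching_device_vector<double> sums(1024, 0.0);
collective::AllReduce<collective::Operation::kSum>(device, sums.data().get(), sums.size());

dh::caching_device_vector<char> payload(256);  // bytes produced locally on this worker
std::vector<size_t> recv_lengths;              // filled with each worker's payload length
dh::caching_device_vector<char> recvbuf;       // concatenation of all workers' payloads
collective::AllGatherV(device, payload.data().get(), payload.size(), &recv_lengths, &recvbuf);
collective::Synchronize(device);               // block until the collectives complete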

View File

@@ -17,32 +17,15 @@ class DeviceCommunicator {
   virtual ~DeviceCommunicator() = default;

   /**
-   * @brief Sum values from all processes and distribute the result back to all processes.
+   * @brief Combines values from all processes and distributes the result back to all processes.
+   *
    * @param send_receive_buffer Buffer storing the data.
    * @param count Number of elements in the buffer.
+   * @param data_type Data type stored in the buffer.
+   * @param op The operation to perform.
    */
-  virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;
-
-  /**
-   * @brief Sum values from all processes and distribute the result back to all processes.
-   * @param send_receive_buffer Buffer storing the data.
-   * @param count Number of elements in the buffer.
-   */
-  virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;
-
-  /**
-   * @brief Sum values from all processes and distribute the result back to all processes.
-   * @param send_receive_buffer Buffer storing the data.
-   * @param count Number of elements in the buffer.
-   */
-  virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;
-
-  /**
-   * @brief Sum values from all processes and distribute the result back to all processes.
-   * @param send_receive_buffer Buffer storing the data.
-   * @param count Number of elements in the buffer.
-   */
-  virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
+  virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
+                         Operation op) = 0;

   /**
    * @brief Gather variable-length values from all processes.

View File

@@ -23,20 +23,18 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
   ~DeviceCommunicatorAdapter() override = default;

-  void AllReduceSum(float *send_receive_buffer, size_t count) override {
-    DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
-  }
-
-  void AllReduceSum(double *send_receive_buffer, size_t count) override {
-    DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
-  }
-
-  void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
-    DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
-  }
-
-  void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
-    DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
+  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
+                 Operation op) override {
+    if (communicator_->GetWorldSize() == 1) {
+      return;
+    }
+
+    dh::safe_cuda(cudaSetDevice(device_ordinal_));
+    auto size = count * GetTypeSize(data_type);
+    host_buffer_.reserve(size);
+    dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
+    communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
+    dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
   }

   void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,

@@ -77,20 +75,6 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
   }

  private:
-  template <collective::DataType data_type, typename T>
-  void DoAllReduceSum(T *send_receive_buffer, size_t count) {
-    if (communicator_->GetWorldSize() == 1) {
-      return;
-    }
-
-    dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    auto size = count * sizeof(T);
-    host_buffer_.reserve(size);
-    dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
-    communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
-    dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
-  }
-
   int const device_ordinal_;
   Communicator *communicator_;
   /// Host buffer used to call communicator functions.

View File

@@ -72,20 +72,18 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
     }
   }

-  void AllReduceSum(float *send_receive_buffer, size_t count) override {
-    DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
-  }
-
-  void AllReduceSum(double *send_receive_buffer, size_t count) override {
-    DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
-  }
-
-  void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
-    DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
-  }
-
-  void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
-    DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
+  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
+                 Operation op) override {
+    if (communicator_->GetWorldSize() == 1) {
+      return;
+    }
+    dh::safe_cuda(cudaSetDevice(device_ordinal_));
+    dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
+                                GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
+                                cuda_stream_));
+    allreduce_bytes_ += count * GetTypeSize(data_type);
+    allreduce_calls_ += 1;
   }

   void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,

@@ -162,17 +160,59 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
     return id;
   }

-  template <ncclDataType_t data_type, typename T>
-  void DoAllReduceSum(T *send_receive_buffer, size_t count) {
-    if (communicator_->GetWorldSize() == 1) {
-      return;
-    }
-
-    dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
-                                nccl_comm_, cuda_stream_));
-    allreduce_bytes_ += count * sizeof(T);
-    allreduce_calls_ += 1;
-  }
+  static ncclDataType_t GetNcclDataType(DataType const &data_type) {
+    ncclDataType_t result;
+    switch (data_type) {
+      case DataType::kInt8:
+        result = ncclInt8;
+        break;
+      case DataType::kUInt8:
+        result = ncclUint8;
+        break;
+      case DataType::kInt32:
+        result = ncclInt32;
+        break;
+      case DataType::kUInt32:
+        result = ncclUint32;
+        break;
+      case DataType::kInt64:
+        result = ncclInt64;
+        break;
+      case DataType::kUInt64:
+        result = ncclUint64;
+        break;
+      case DataType::kFloat:
+        result = ncclFloat;
+        break;
+      case DataType::kDouble:
+        result = ncclDouble;
+        break;
+      default:
+        LOG(FATAL) << "Unknown data type.";
+    }
+    return result;
+  }
+
+  static ncclRedOp_t GetNcclRedOp(Operation const &op) {
+    ncclRedOp_t result;
+    switch (op) {
+      case Operation::kMax:
+        result = ncclMax;
+        break;
+      case Operation::kMin:
+        result = ncclMin;
+        break;
+      case Operation::kSum:
+        result = ncclSum;
+        break;
+      case Operation::kBitwiseAND:
+      case Operation::kBitwiseOR:
+      case Operation::kBitwiseXOR:
+        LOG(FATAL) << "Not implemented yet.";
+      default:
+        LOG(FATAL) << "Unknown reduce operation.";
+    }
+    return result;
+  }

   int const device_ordinal_;
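
Put together with the new header, a single templated call now carries both the element type and the reduce operation down to NCCL. A rough trace for a float max-reduction, with hypothetical variable names (`grads`, `n`, `device`):

// Caller side (any .cu file that includes communicator-inl.cuh):
//   collective::AllReduce<collective::Operation::kMax>(device, grads, n);
// forwards to the device communicator as
//   AllReduce(grads, n, DataType::kFloat, Operation::kMax);
// which NcclDeviceCommunicator::AllReduce turns into
//   ncclAllReduce(grads, grads, n, GetNcclDataType(DataType::kFloat) /* ncclFloat */,
//                 GetNcclRedOp(Operation::kMax) /* ncclMax */, nccl_comm_, cuda_stream_);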

View File

@@ -12,8 +12,7 @@
 #include <memory>
 #include <utility>

-#include "../collective/communicator.h"
-#include "../collective/device_communicator.cuh"
+#include "../collective/communicator-inl.cuh"
 #include "categorical.h"
 #include "common.h"
 #include "device_helpers.cuh"

@@ -510,7 +509,6 @@ void SketchContainer::AllReduce() {
   }
   timer_.Start(__func__);

-  auto* communicator = collective::Communicator::GetDevice(device_);
   // Reduce the overhead on syncing.
   size_t global_sum_rows = num_rows_;
   collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);

@@ -531,14 +529,15 @@
   auto offset = rank * d_columns_ptr.size();
   thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
                gathered_ptrs.begin() + offset);
-  communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());
+  collective::AllReduce<collective::Operation::kSum>(device_, gathered_ptrs.data().get(),
+                                                     gathered_ptrs.size());

   // Get the data from all workers.
   std::vector<size_t> recv_lengths;
   dh::caching_device_vector<char> recvbuf;
-  communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
-                           &recv_lengths, &recvbuf);
-  communicator->Synchronize();
+  collective::AllGatherV(device_, this->Current().data().get(),
+                         dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf);
+  collective::Synchronize(device_);

   // Segment the received data.
   auto s_recvbuf = dh::ToSpan(recvbuf);
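
The remaining files repeat the same mechanical change: instead of fetching a DeviceCommunicator and calling its type-specific AllReduceSum, call sites use the templated free function. Schematically, with a hypothetical buffer name (`values`):

// Before this commit:
//   auto* communicator = collective::Communicator::GetDevice(device);
//   communicator->AllReduceSum(values.data().get(), values.size());
// After this commit:
//   collective::AllReduce<collective::Operation::kSum>(device, values.data().get(), values.size());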

View File

@@ -11,7 +11,7 @@
 #include <tuple>
 #include <utility>

-#include "../collective/device_communicator.cuh"
+#include "../collective/communicator-inl.cuh"
 #include "../common/algorithm.cuh"        // SegmentedArgSort
 #include "../common/optional_weight.h"    // OptionalWeights
 #include "../common/threading_utils.cuh"  // UnravelTrapeziodIdx,SegmentedTrapezoidThreads

@@ -205,8 +205,7 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
   if (collective::IsDistributed()) {
     int32_t device = dh::CurrentDevice();
     CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
-    auto* communicator = collective::Communicator::GetDevice(device);
-    communicator->AllReduceSum(results.data(), results.size());
+    collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
   }
   auto reduce_in = dh::MakeTransformIterator<Pair>(
       thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {

View File

@@ -11,7 +11,7 @@
 #include <cstddef>  // std::size_t

-#include "../collective/device_communicator.cuh"  // DeviceCommunicator
+#include "../collective/communicator-inl.cuh"
 #include "../common/device_helpers.cuh"  // dh::MakeTransformIterator

 #include "fit_stump.h"
 #include "xgboost/base.h"  // GradientPairPrecise, GradientPair, XGBOOST_DEVICE

@@ -49,8 +49,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
   thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
                         thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));

-  collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(ctx->gpu_id);
-  communicator->AllReduceSum(reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
+  collective::AllReduce<collective::Operation::kSum>(
+      ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);

   thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
                      [=] XGBOOST_DEVICE(std::size_t i) mutable {

View File

@@ -12,7 +12,7 @@
 #include <utility>
 #include <vector>

-#include "../collective/device_communicator.cuh"
+#include "../collective/communicator-inl.cuh"
 #include "../common/bitfield.h"
 #include "../common/categorical.h"
 #include "../common/cuda_context.cuh"  // CUDAContext

@@ -546,12 +546,13 @@ struct GPUHistMakerDevice {
   }

   // num histograms is the number of contiguous histograms in memory to reduce over
-  void AllReduceHist(int nidx, collective::DeviceCommunicator* communicator, int num_histograms) {
+  void AllReduceHist(int nidx, int num_histograms) {
     monitor.Start("AllReduce");
     auto d_node_hist = hist.GetNodeHistogram(nidx).data();
     using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
-    communicator->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
-                               page->Cuts().TotalBins() * 2 * num_histograms);
+    collective::AllReduce<collective::Operation::kSum>(
+        ctx_->gpu_id, reinterpret_cast<ReduceT*>(d_node_hist),
+        page->Cuts().TotalBins() * 2 * num_histograms);

     monitor.Stop("AllReduce");
   }

@@ -559,8 +560,7 @@ struct GPUHistMakerDevice {
   /**
   * \brief Build GPU local histograms for the left and right child of some parent node
   */
-  void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates,
-                          collective::DeviceCommunicator* communicator, const RegTree& tree) {
+  void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, const RegTree& tree) {
     if (candidates.empty()) return;
     // Some nodes we will manually compute histograms
     // others we will do by subtraction

@@ -591,7 +591,7 @@ struct GPUHistMakerDevice {
     // Reduce all in one go
     // This gives much better latency in a distributed setting
     // when processing a large batch
-    this->AllReduceHist(hist_nidx.at(0), communicator, hist_nidx.size());
+    this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());

     for (size_t i = 0; i < subtraction_nidx.size(); i++) {
       auto build_hist_nidx = hist_nidx.at(i);

@@ -601,7 +601,7 @@ struct GPUHistMakerDevice {
       if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
         // Calculate other histogram manually
         this->BuildHist(subtraction_trick_nidx);
-        this->AllReduceHist(subtraction_trick_nidx, communicator, 1);
+        this->AllReduceHist(subtraction_trick_nidx, 1);
       }
     }
   }

@@ -659,7 +659,7 @@ struct GPUHistMakerDevice {
                          parent.RightChild());
   }

-  GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) {
+  GPUExpandEntry InitRoot(RegTree* p_tree) {
     constexpr bst_node_t kRootNIdx = 0;
     dh::XGBCachingDeviceAllocator<char> alloc;
     auto quantiser = *this->quantiser;

@@ -676,7 +676,7 @@ struct GPUHistMakerDevice {
     hist.AllocateHistograms({kRootNIdx});
     this->BuildHist(kRootNIdx);
-    this->AllReduceHist(kRootNIdx, communicator, 1);
+    this->AllReduceHist(kRootNIdx, 1);

     // Remember root stats
     auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);

@@ -692,7 +692,6 @@ struct GPUHistMakerDevice {
   void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
                   ObjInfo const* task, RegTree* p_tree,
-                  collective::DeviceCommunicator* communicator,
                   HostDeviceVector<bst_node_t>* p_out_position) {
     auto& tree = *p_tree;
     // Process maximum 32 nodes at a time

@@ -703,7 +702,7 @@ struct GPUHistMakerDevice {
     monitor.Stop("Reset");

     monitor.Start("InitRoot");
-    driver.Push({ this->InitRoot(p_tree, communicator) });
+    driver.Push({this->InitRoot(p_tree)});
     monitor.Stop("InitRoot");

     // The set of leaves that can be expanded asynchronously

@@ -730,7 +729,7 @@ struct GPUHistMakerDevice {
     monitor.Stop("UpdatePosition");

     monitor.Start("BuildHist");
-    this->BuildHistLeftRight(filtered_expand_set, communicator, tree);
+    this->BuildHistLeftRight(filtered_expand_set, tree);
     monitor.Stop("BuildHist");

     monitor.Start("EvaluateSplits");

@@ -851,8 +850,7 @@ class GPUHistMaker : public TreeUpdater {
     monitor_.Stop("InitData");

     gpair->SetDevice(ctx_->gpu_id);
-    auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id);
-    maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position);
+    maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
   }

   bool UpdatePredictionCache(const DMatrix* data,
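
With the communicator parameter dropped, the GPU hist updater reaches the collective through the context's device ordinal alone. Sketched from the hunks above:

//   GPUHistMaker::UpdateTree(...):      maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
//   GPUHistMakerDevice::InitRoot(...):  this->AllReduceHist(kRootNIdx, 1);
//   GPUHistMakerDevice::AllReduceHist:  collective::AllReduce<collective::Operation::kSum>(
//                                           ctx_->gpu_id, reinterpret_cast<ReduceT*>(d_node_hist),
//                                           page->Cuts().TotalBins() * 2 * num_histograms);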

View File

@@ -8,6 +8,7 @@
 #include <string>  // for string

 #include "../../../src/collective/nccl_device_communicator.cuh"
+#include "../../../src/collective/communicator-inl.cuh"

 namespace xgboost {
 namespace collective {

View File

@@ -1,7 +1,7 @@
 #include <gtest/gtest.h>

 #include "test_quantile.h"
 #include "../helpers.h"
-#include "../../../src/collective/device_communicator.cuh"
+#include "../../../src/collective/communicator-inl.cuh"
 #include "../../../src/common/hist_util.cuh"
 #include "../../../src/common/quantile.cuh"

@@ -464,10 +464,9 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) {
       thrust::copy(thrust::device, local_data.data(),
                    local_data.data() + local_data.size(),
                    all_workers.begin() + local_data.size() * rank);
-      collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(device);
-
-      communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
-      communicator->Synchronize();
+      collective::AllReduce<collective::Operation::kSum>(device, all_workers.data().get(),
+                                                         all_workers.size());
+      collective::Synchronize(device);

       auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
       std::vector<float> h_base_line(base_line.size());

View File

@@ -36,7 +36,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
   int count = 3;
   thrust::device_vector<double> buffer(count, 0);
   thrust::sequence(buffer.begin(), buffer.end());
-  adapter.AllReduceSum(buffer.data().get(), count);
+  adapter.AllReduce(buffer.data().get(), count, DataType::kDouble, Operation::kSum);
   thrust::host_vector<double> host_buffer = buffer;
   EXPECT_EQ(host_buffer.size(), count);
   for (auto i = 0; i < count; i++) {