merge latest changes

This commit is contained in:
amdsc21
2023-06-15 21:39:14 +02:00
18 changed files with 284 additions and 151 deletions

View File

@@ -17,7 +17,7 @@
#include <tuple>
#include <utility>
#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../common/algorithm.cuh" // SegmentedArgSort
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
@@ -231,8 +231,7 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
auto* communicator = collective::Communicator::GetDevice(device);
communicator->AllReduceSum(results.data(), results.size());
collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
}
auto reduce_in = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {