merge latest changes
This commit is contained in:
81
src/collective/communicator-inl.cuh
Normal file
81
src/collective/communicator-inl.cuh
Normal file
@@ -0,0 +1,81 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Reduce values from all processes and distribute the result back to all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
* @param length_bytes Length in bytes of the input data.
|
||||
* @param segments Size of each segment.
|
||||
* @param receive_buffer Buffer storing the output data.
|
||||
*/
|
||||
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Synchronize device operations.
|
||||
* @param device ID of the device.
|
||||
*/
|
||||
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
7
src/collective/communicator-inl.hip.h
Normal file
7
src/collective/communicator-inl.hip.h
Normal file
@@ -0,0 +1,7 @@
|
||||
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "communicator-inl.cuh"
|
||||
@@ -17,32 +17,15 @@ class DeviceCommunicator {
|
||||
virtual ~DeviceCommunicator() = default;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
* @param data_type Data type stored in the buffer.
|
||||
* @param op The operation to perform.
|
||||
*/
|
||||
virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
|
||||
@@ -23,20 +23,28 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduceSum(float *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
|
||||
}
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#endif
|
||||
auto size = count * GetTypeSize(data_type);
|
||||
host_buffer_.reserve(size);
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(host_buffer_.data(), send_receive_buffer, size, hipMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(hipMemcpy(send_receive_buffer, host_buffer_.data(), size, hipMemcpyDefault));
|
||||
#endif
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
@@ -93,32 +101,6 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
private:
|
||||
template <collective::DataType data_type, typename T>
|
||||
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
#endif
|
||||
|
||||
auto size = count * sizeof(T);
|
||||
host_buffer_.reserve(size);
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(host_buffer_.data(), send_receive_buffer, size, hipMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
|
||||
dh::safe_cuda(hipMemcpy(send_receive_buffer, host_buffer_.data(), size, hipMemcpyDefault));
|
||||
#else
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
#endif
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
|
||||
@@ -81,20 +81,18 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
}
|
||||
}
|
||||
|
||||
void AllReduceSum(float *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
|
||||
}
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
cuda_stream_));
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
@@ -192,22 +190,59 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
return id;
|
||||
}
|
||||
|
||||
template <ncclDataType_t data_type, typename T>
|
||||
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
static ncclDataType_t GetNcclDataType(DataType const &data_type) {
|
||||
ncclDataType_t result;
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
result = ncclInt8;
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
result = ncclUint8;
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
result = ncclInt32;
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
result = ncclUint32;
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
result = ncclInt64;
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
result = ncclUint64;
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
result = ncclFloat;
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
result = ncclDouble;
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
#endif
|
||||
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
|
||||
nccl_comm_, cuda_stream_));
|
||||
allreduce_bytes_ += count * sizeof(T);
|
||||
allreduce_calls_ += 1;
|
||||
static ncclRedOp_t GetNcclRedOp(Operation const &op) {
|
||||
ncclRedOp_t result;
|
||||
switch (op) {
|
||||
case Operation::kMax:
|
||||
result = ncclMax;
|
||||
break;
|
||||
case Operation::kMin:
|
||||
result = ncclMin;
|
||||
break;
|
||||
case Operation::kSum:
|
||||
result = ncclSum;
|
||||
break;
|
||||
case Operation::kBitwiseAND:
|
||||
case Operation::kBitwiseOR:
|
||||
case Operation::kBitwiseXOR:
|
||||
LOG(FATAL) << "Not implemented yet.";
|
||||
default:
|
||||
LOG(FATAL) << "Unknown reduce operation.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
|
||||
@@ -12,8 +12,7 @@
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator.h"
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "categorical.h"
|
||||
#include "common.h"
|
||||
#include "device_helpers.cuh"
|
||||
@@ -600,7 +599,6 @@ void SketchContainer::AllReduce() {
|
||||
}
|
||||
|
||||
timer_.Start(__func__);
|
||||
auto* communicator = collective::Communicator::GetDevice(device_);
|
||||
// Reduce the overhead on syncing.
|
||||
size_t global_sum_rows = num_rows_;
|
||||
collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
|
||||
@@ -621,14 +619,15 @@ void SketchContainer::AllReduce() {
|
||||
auto offset = rank * d_columns_ptr.size();
|
||||
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
|
||||
gathered_ptrs.begin() + offset);
|
||||
communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());
|
||||
collective::AllReduce<collective::Operation::kSum>(device_, gathered_ptrs.data().get(),
|
||||
gathered_ptrs.size());
|
||||
|
||||
// Get the data from all workers.
|
||||
std::vector<size_t> recv_lengths;
|
||||
dh::caching_device_vector<char> recvbuf;
|
||||
communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
|
||||
&recv_lengths, &recvbuf);
|
||||
communicator->Synchronize();
|
||||
collective::AllGatherV(device_, this->Current().data().get(),
|
||||
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf);
|
||||
collective::Synchronize(device_);
|
||||
|
||||
// Segment the received data.
|
||||
auto s_recvbuf = dh::ToSpan(recvbuf);
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../common/algorithm.cuh" // SegmentedArgSort
|
||||
#include "../common/optional_weight.h" // OptionalWeights
|
||||
#include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
|
||||
@@ -231,8 +231,7 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
|
||||
if (collective::IsDistributed()) {
|
||||
int32_t device = dh::CurrentDevice();
|
||||
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
|
||||
auto* communicator = collective::Communicator::GetDevice(device);
|
||||
communicator->AllReduceSum(results.data(), results.size());
|
||||
collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
|
||||
}
|
||||
auto reduce_in = dh::MakeTransformIterator<Pair>(
|
||||
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "../collective/device_communicator.cuh" // DeviceCommunicator
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
|
||||
#include "fit_stump.h"
|
||||
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
|
||||
@@ -55,8 +55,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
|
||||
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
|
||||
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
|
||||
|
||||
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(ctx->gpu_id);
|
||||
communicator->AllReduceSum(reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
|
||||
collective::AllReduce<collective::Operation::kSum>(
|
||||
ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
|
||||
|
||||
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
|
||||
[=] XGBOOST_DEVICE(std::size_t i) mutable {
|
||||
|
||||
@@ -42,12 +42,6 @@ struct GPUExpandEntry {
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
|
||||
if (param.max_depth > 0 && depth >= param.max_depth) return false;
|
||||
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bst_float GetLossChange() const {
|
||||
return split.loss_chg;
|
||||
}
|
||||
|
||||
@@ -26,12 +26,6 @@ struct ExpandEntryImpl {
|
||||
}
|
||||
[[nodiscard]] bst_node_t GetNodeId() const { return nid; }
|
||||
|
||||
static bool ChildIsValid(TrainParam const& param, bst_node_t depth, bst_node_t num_leaves) {
|
||||
if (param.max_depth > 0 && depth >= param.max_depth) return false;
|
||||
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool IsValid(TrainParam const& param, bst_node_t num_leaves) const {
|
||||
return static_cast<Impl const*>(this)->IsValidImpl(param, num_leaves);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/device_communicator.cuh"
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/categorical.h"
|
||||
|
||||
@@ -605,12 +605,13 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
// num histograms is the number of contiguous histograms in memory to reduce over
|
||||
void AllReduceHist(int nidx, collective::DeviceCommunicator* communicator, int num_histograms) {
|
||||
void AllReduceHist(int nidx, int num_histograms) {
|
||||
monitor.Start("AllReduce");
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
|
||||
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
|
||||
communicator->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
|
||||
page->Cuts().TotalBins() * 2 * num_histograms);
|
||||
collective::AllReduce<collective::Operation::kSum>(
|
||||
ctx_->gpu_id, reinterpret_cast<ReduceT*>(d_node_hist),
|
||||
page->Cuts().TotalBins() * 2 * num_histograms);
|
||||
|
||||
monitor.Stop("AllReduce");
|
||||
}
|
||||
@@ -618,8 +619,7 @@ struct GPUHistMakerDevice {
|
||||
/**
|
||||
* \brief Build GPU local histograms for the left and right child of some parent node
|
||||
*/
|
||||
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates,
|
||||
collective::DeviceCommunicator* communicator, const RegTree& tree) {
|
||||
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, const RegTree& tree) {
|
||||
if (candidates.empty()) return;
|
||||
// Some nodes we will manually compute histograms
|
||||
// others we will do by subtraction
|
||||
@@ -650,7 +650,7 @@ struct GPUHistMakerDevice {
|
||||
// Reduce all in one go
|
||||
// This gives much better latency in a distributed setting
|
||||
// when processing a large batch
|
||||
this->AllReduceHist(hist_nidx.at(0), communicator, hist_nidx.size());
|
||||
this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());
|
||||
|
||||
for (size_t i = 0; i < subtraction_nidx.size(); i++) {
|
||||
auto build_hist_nidx = hist_nidx.at(i);
|
||||
@@ -660,7 +660,7 @@ struct GPUHistMakerDevice {
|
||||
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
|
||||
// Calculate other histogram manually
|
||||
this->BuildHist(subtraction_trick_nidx);
|
||||
this->AllReduceHist(subtraction_trick_nidx, communicator, 1);
|
||||
this->AllReduceHist(subtraction_trick_nidx, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -718,7 +718,7 @@ struct GPUHistMakerDevice {
|
||||
parent.RightChild());
|
||||
}
|
||||
|
||||
GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) {
|
||||
GPUExpandEntry InitRoot(RegTree* p_tree) {
|
||||
constexpr bst_node_t kRootNIdx = 0;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto quantiser = *this->quantiser;
|
||||
@@ -735,7 +735,7 @@ struct GPUHistMakerDevice {
|
||||
|
||||
hist.AllocateHistograms({kRootNIdx});
|
||||
this->BuildHist(kRootNIdx);
|
||||
this->AllReduceHist(kRootNIdx, communicator, 1);
|
||||
this->AllReduceHist(kRootNIdx, 1);
|
||||
|
||||
// Remember root stats
|
||||
auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
|
||||
@@ -751,7 +751,6 @@ struct GPUHistMakerDevice {
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
|
||||
ObjInfo const* task, RegTree* p_tree,
|
||||
collective::DeviceCommunicator* communicator,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
auto& tree = *p_tree;
|
||||
// Process maximum 32 nodes at a time
|
||||
@@ -762,7 +761,7 @@ struct GPUHistMakerDevice {
|
||||
monitor.Stop("Reset");
|
||||
|
||||
monitor.Start("InitRoot");
|
||||
driver.Push({ this->InitRoot(p_tree, communicator) });
|
||||
driver.Push({this->InitRoot(p_tree)});
|
||||
monitor.Stop("InitRoot");
|
||||
|
||||
// The set of leaves that can be expanded asynchronously
|
||||
@@ -789,7 +788,7 @@ struct GPUHistMakerDevice {
|
||||
monitor.Stop("UpdatePosition");
|
||||
|
||||
monitor.Start("BuildHist");
|
||||
this->BuildHistLeftRight(filtered_expand_set, communicator, tree);
|
||||
this->BuildHistLeftRight(filtered_expand_set, tree);
|
||||
monitor.Stop("BuildHist");
|
||||
|
||||
monitor.Start("EvaluateSplits");
|
||||
@@ -920,8 +919,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
monitor_.Stop("InitData");
|
||||
|
||||
gpair->SetDevice(ctx_->gpu_id);
|
||||
auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id);
|
||||
maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position);
|
||||
maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
|
||||
Reference in New Issue
Block a user