Switch to per-thread default stream (#9396)
parent 7a0ccfbb49
commit f7f673b00c
@@ -127,6 +127,7 @@ endfunction(format_gencode_flags flags)
 # Set CUDA related flags to target. Must be used after code `format_gencode_flags`.
 function(xgboost_set_cuda_flags target)
   target_compile_options(${target} PRIVATE
+    $<$<COMPILE_LANGUAGE:CUDA>:--default-stream per-thread>
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
     $<$<COMPILE_LANGUAGE:CUDA>:${GEN_CODE}>
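Note: `--default-stream per-thread` is the nvcc flag that gives every host thread its own default stream, so CUDA work launched without an explicit stream no longer serializes on the single legacy stream 0. The rest of this commit relies on that flag when it drops the explicitly managed streams below. A minimal standalone sketch (not part of this commit; the kernel and function names are invented for illustration, and it should be compiled with `nvcc --default-stream per-thread`):

    // Sketch only: with --default-stream per-thread, the two launches below go to
    // different per-thread default streams and may overlap; with the legacy default
    // stream they would serialize on stream 0.
    #include <thread>
    #include <cuda_runtime.h>

    __global__ void SpinKernel(float *out, int iters) {
      float v = 0.0f;
      for (int i = 0; i < iters; ++i) {
        v += sinf(static_cast<float>(i));
      }
      out[threadIdx.x] = v;
    }

    void LaunchFromThread(float *buf) {
      // No explicit stream: the launch targets the calling thread's default stream.
      SpinKernel<<<1, 32>>>(buf, 1 << 20);
      cudaStreamSynchronize(cudaStreamPerThread);
    }

    int main() {
      float *a = nullptr, *b = nullptr;
      cudaMalloc(&a, 32 * sizeof(float));
      cudaMalloc(&b, 32 * sizeof(float));
      std::thread t1(LaunchFromThread, a);
      std::thread t2(LaunchFromThread, b);
      t1.join();
      t2.join();
      cudaFree(a);
      cudaFree(b);
      return 0;
    }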
@@ -44,16 +44,12 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
   nccl_unique_id_ = GetUniqueId();
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
   dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
-  dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
 }

 NcclDeviceCommunicator::~NcclDeviceCommunicator() {
   if (world_size_ == 1) {
     return;
   }
-  if (cuda_stream_) {
-    dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
-  }
   if (nccl_comm_) {
     dh::safe_nccl(ncclCommDestroy(nccl_comm_));
   }
@@ -123,8 +119,8 @@ ncclRedOp_t GetNcclRedOp(Operation const &op) {

 template <typename Func>
 void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
-                         std::size_t size, cudaStream_t stream) {
-  dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
+                         std::size_t size) {
+  dh::LaunchN(size, [=] __device__(std::size_t idx) {
     auto result = device_buffer[idx];
     for (auto rank = 1; rank < world_size; rank++) {
       result = func(result, device_buffer[rank * size + idx]);
@@ -142,25 +138,22 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si

   // First gather data from all the workers.
   dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
-                              nccl_comm_, cuda_stream_));
+                              nccl_comm_, dh::DefaultStream()));
   if (needs_sync_) {
-    dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+    dh::DefaultStream().Sync();
   }

   // Then reduce locally.
   auto *out_buffer = static_cast<char *>(send_receive_buffer);
   switch (op) {
     case Operation::kBitwiseAND:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
-                          cuda_stream_);
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
       break;
     case Operation::kBitwiseOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
-                          cuda_stream_);
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
       break;
     case Operation::kBitwiseXOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
-                          cuda_stream_);
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
       break;
     default:
       LOG(FATAL) << "Not a bitwise reduce operation.";
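Note: this hunk enqueues the NCCL collective on the default stream and then synchronizes that stream only when `needs_sync_` is set. The standalone sketch below shows the same ordering (hedged: single-GPU communicator setup, no error checking, and none of it is the actual XGBoost code):

    // Sketch only: an NCCL collective issued on the per-thread default stream,
    // followed by a synchronization of that stream (what dh::DefaultStream().Sync()
    // does in the hunk above). Single GPU, no error checking.
    #include <nccl.h>
    #include <cuda_runtime.h>

    int main() {
      int device = 0;
      cudaSetDevice(device);

      ncclComm_t comm;
      ncclCommInitAll(&comm, 1, &device);  // one-rank "world", purely for illustration

      float *buf = nullptr;
      cudaMalloc(&buf, 4 * sizeof(float));

      // Enqueue the collective on the calling thread's default stream.
      ncclAllReduce(buf, buf, 4, ncclFloat, ncclSum, comm, cudaStreamPerThread);

      // Synchronize the same stream before touching the result on the host.
      cudaStreamSynchronize(cudaStreamPerThread);

      cudaFree(buf);
      ncclCommDestroy(comm);
      return 0;
    }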
@@ -179,7 +172,7 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
   } else {
     dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
                                 GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
-                                cuda_stream_));
+                                dh::DefaultStream()));
   }
   allreduce_bytes_ += count * GetTypeSize(data_type);
   allreduce_calls_ += 1;
@@ -206,7 +199,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
   for (int32_t i = 0; i < world_size_; ++i) {
     size_t as_bytes = segments->at(i);
     dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
-                                ncclChar, i, nccl_comm_, cuda_stream_));
+                                ncclChar, i, nccl_comm_, dh::DefaultStream()));
     offset += as_bytes;
   }
   dh::safe_nccl(ncclGroupEnd());
@@ -217,7 +210,7 @@ void NcclDeviceCommunicator::Synchronize() {
     return;
   }
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
-  dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+  dh::DefaultStream().Sync();
 }

 }  // namespace collective
@@ -77,7 +77,6 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
   int const world_size_;
   int const rank_;
   ncclComm_t nccl_comm_{};
-  cudaStream_t cuda_stream_{};
   ncclUniqueId nccl_unique_id_{};
   size_t allreduce_bytes_{0};  // Keep statistics of the number of bytes communicated.
   size_t allreduce_calls_{0};  // Keep statistics of the number of reduce calls.
@@ -1176,7 +1176,7 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
   dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
 }

-inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamLegacy}; }
+inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamPerThread}; }

 class CUDAStream {
   cudaStream_t stream_;
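Note: `cudaStreamLegacy` and `cudaStreamPerThread` are the two special handles CUDA defines for the default stream; the former always names the global legacy stream 0, while the latter resolves to the calling thread's own default stream when the code is built with `--default-stream per-thread`. Below is a rough approximation of the wrapper being switched here; its exact shape is an assumption, since the real `CUDAStreamView` in device_helpers.cuh has more members than shown:

    // Hedged approximation, not the actual XGBoost implementation.
    #include <cuda_runtime.h>

    class CUDAStreamView {
     public:
      explicit CUDAStreamView(cudaStream_t s) : stream_{s} {}
      // Implicit conversion lets the view be passed directly where a raw
      // cudaStream_t is expected, as in the NCCL calls earlier in this diff.
      operator cudaStream_t() const { return stream_; }  // NOLINT
      void Sync() const { cudaStreamSynchronize(stream_); }

     private:
      cudaStream_t stream_;
    };

    // After this commit the default view names the per-thread stream.
    inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamPerThread}; }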
@@ -135,12 +135,12 @@ void LaunchGetColumnSizeKernel(std::int32_t device, IterSpan<BatchIt> batch_iter
     CHECK(!force_use_u64);
     auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
     auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
-    dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+    dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
         kernel, batch_iter, is_valid, out_column_size);
   } else {
     auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
     auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
-    dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+    dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
         kernel, batch_iter, is_valid, out_column_size);
   }
 } else {
@@ -18,12 +18,10 @@ RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
   dh::safe_cuda(cudaSetDevice(device_idx_));
   ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
   thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
-  dh::safe_cuda(cudaStreamCreate(&stream_));
 }

 RowPartitioner::~RowPartitioner() {
   dh::safe_cuda(cudaSetDevice(device_idx_));
-  dh::safe_cuda(cudaStreamDestroy(stream_));
 }

 common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
@@ -116,7 +116,7 @@ template <typename RowIndexT, typename OpT, typename OpDataT>
 void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
                        common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
                        common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
-                       dh::device_vector<int8_t>* tmp, cudaStream_t stream) {
+                       dh::device_vector<int8_t>* tmp) {
   dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
   WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
                                              d_counts.data()};
@@ -135,12 +135,12 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
   size_t temp_bytes = 0;
   if (tmp->empty()) {
     cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
-                                   IndexFlagOp(), total_rows, stream);
+                                   IndexFlagOp(), total_rows);
     tmp->resize(temp_bytes);
   }
   temp_bytes = tmp->size();
   cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
-                                 discard_write_iterator, IndexFlagOp(), total_rows, stream);
+                                 discard_write_iterator, IndexFlagOp(), total_rows);

   constexpr int kBlockSize = 256;

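Note: `cub::DeviceScan::InclusiveScan` takes an optional stream argument that defaults to the null stream, so dropping `stream` here means the scan now runs on the default stream, which the new compile flag makes per-thread. The two-phase pattern above (a size query followed by the real scan) is illustrated by this standalone sketch, which uses generic iterators and `InclusiveSum` rather than the custom functor from the hunk:

    // Sketch only: the two-phase CUB scan pattern. The first call (null temp
    // pointer) reports the scratch size, the second performs the scan on the
    // default stream, which is per-thread under the new compile flag.
    #include <cub/cub.cuh>
    #include <thrust/device_vector.h>
    #include <cuda_runtime.h>

    int main() {
      thrust::device_vector<int> in(1024, 1);
      thrust::device_vector<int> out(1024);

      void *d_temp = nullptr;
      size_t temp_bytes = 0;
      cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, in.data().get(),
                                    out.data().get(), static_cast<int>(in.size()));

      thrust::device_vector<char> temp(temp_bytes);
      cub::DeviceScan::InclusiveSum(temp.data().get(), temp_bytes, in.data().get(),
                                    out.data().get(), static_cast<int>(in.size()));

      cudaStreamSynchronize(cudaStreamPerThread);
      return 0;
    }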
@@ -149,7 +149,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
   const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);

   SortPositionCopyKernel<kBlockSize, RowIndexT, OpDataT>
-      <<<grid_size, kBlockSize, 0, stream>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
+      <<<grid_size, kBlockSize, 0>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
 }

 struct NodePositionInfo {
@@ -221,7 +221,6 @@ class RowPartitioner {
   dh::device_vector<int8_t> tmp_;
   dh::PinnedMemory pinned_;
   dh::PinnedMemory pinned2_;
-  cudaStream_t stream_;

  public:
   RowPartitioner(int device_idx, size_t num_rows);
@@ -278,7 +277,7 @@ class RowPartitioner {
     }
     dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
                                   h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
-                                  cudaMemcpyDefault, stream_));
+                                  cudaMemcpyDefault));

     // Temporary arrays
     auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
@@ -287,12 +286,12 @@ class RowPartitioner {
     // Partition the rows according to the operator
     SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
         dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
-        total_rows, op, &tmp_, stream_);
+        total_rows, op, &tmp_);
     dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
-                                  cudaMemcpyDefault, stream_));
+                                  cudaMemcpyDefault));
     // TODO(Rory): this synchronisation hurts performance a lot
     // Future optimisation should find a way to skip this
-    dh::safe_cuda(cudaStreamSynchronize(stream_));
+    dh::DefaultStream().Sync();

     // Update segments
     for (size_t i = 0; i < nidx.size(); i++) {
@@ -327,13 +326,13 @@ class RowPartitioner {
     dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
     dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
                                   sizeof(NodePositionInfo) * ridx_segments_.size(),
-                                  cudaMemcpyDefault, stream_));
+                                  cudaMemcpyDefault));

     constexpr int kBlockSize = 512;
     const int kItemsThread = 8;
     const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
     common::Span<const RowIndexT> d_ridx(ridx_.data().get(), ridx_.size());
-    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0, stream_>>>(
+    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0>>>(
         dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
   }
 };
@@ -73,7 +73,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
   dh::device_vector<int8_t> tmp;
   SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
                                                  dh::ToSpan(ridx_tmp), dh::ToSpan(counts),
-                                                 total_rows, op, &tmp, nullptr);
+                                                 total_rows, op, &tmp);

   auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; };
   for (size_t i = 0; i < segments.size(); i++) {