Retry switching to per-thread default stream (#9416)

Author: Rong Ou, 2023-07-25 16:09:12 -07:00 (committed by GitHub)
Commit: 7579905e18 (parent: 54579da4d7)
9 changed files with 37 additions and 36 deletions
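
Background: without an explicit stream argument, CUDA work goes to the legacy default stream (cudaStreamLegacy), which synchronizes across all host threads. Building with nvcc's --default-stream per-thread instead gives each host thread its own default stream (cudaStreamPerThread), so kernels launched from different threads can overlap. A minimal sketch of that difference, not part of this commit; the kernel, sizes, and thread count are made up for illustration:

    // build with: nvcc --default-stream per-thread overlap.cu
    #include <cuda_runtime.h>
    #include <thread>
    #include <vector>

    constexpr int kN = 1 << 20;

    __global__ void Scale(float* x, int n) {
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
        x[i] = x[i] * 2.0f + 1.0f;
      }
    }

    int main() {
      std::vector<std::thread> workers;
      for (int t = 0; t < 4; ++t) {
        workers.emplace_back([] {
          float* x;
          cudaMalloc(&x, kN * sizeof(float));
          // No stream argument: under --default-stream per-thread each thread's launch
          // goes to its own default stream, so these kernels can run concurrently
          // instead of serializing on the single legacy default stream.
          Scale<<<256, 256>>>(x, kN);
          cudaStreamSynchronize(cudaStreamPerThread);
          cudaFree(x);
        });
      }
      for (auto& w : workers) w.join();
      return 0;
    }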


@@ -50,6 +50,7 @@ option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
 option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
 ## CUDA
 option(USE_CUDA "Build with GPU acceleration" OFF)
+option(USE_PER_THREAD_DEFAULT_STREAM "Build with per-thread default stream" ON)
 option(USE_NCCL "Build with NCCL to enable distributed GPU support." OFF)
 option(BUILD_WITH_SHARED_NCCL "Build with shared NCCL library." OFF)
 set(GPU_COMPUTE_VER "" CACHE STRING
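
Usage note (not part of the diff): the new option defaults to ON, so a plain GPU build picks up per-thread default streams; it can be switched off at configure time, for example:

    cmake -DUSE_CUDA=ON -DUSE_PER_THREAD_DEFAULT_STREAM=OFF ..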


@@ -133,6 +133,11 @@ function(xgboost_set_cuda_flags target)
     $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=${OpenMP_CXX_FLAGS}>
     $<$<COMPILE_LANGUAGE:CUDA>:-Xfatbin=-compress-all>)
+  if (USE_PER_THREAD_DEFAULT_STREAM)
+    target_compile_options(${target} PRIVATE
+      $<$<COMPILE_LANGUAGE:CUDA>:--default-stream per-thread>)
+  endif (USE_PER_THREAD_DEFAULT_STREAM)
   if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")
     set_property(TARGET ${target} PROPERTY CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
   endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")
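
When nvcc is given --default-stream per-thread it also predefines CUDA_API_PER_THREAD_DEFAULT_STREAM for the CUDA sources it compiles, which is what the DefaultStream() change further down keys off. A host-only translation unit built without that flag can opt in to the same semantics by defining the macro before including the runtime header; a minimal sketch (the file and function are hypothetical, the macro and header are standard CUDA):

    // Hypothetical host-only .cc file opting in to per-thread default stream semantics.
    #define CUDA_API_PER_THREAD_DEFAULT_STREAM 1
    #include <cuda_runtime.h>

    void SyncThisThreadsWork() {
      // With the macro defined, "default stream" runtime calls in this translation
      // unit resolve to the calling thread's default stream.
      cudaStreamSynchronize(nullptr);
    }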


@@ -44,16 +44,12 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
   nccl_unique_id_ = GetUniqueId();
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
   dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
-  dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
 }
 
 NcclDeviceCommunicator::~NcclDeviceCommunicator() {
   if (world_size_ == 1) {
     return;
   }
-  if (cuda_stream_) {
-    dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
-  }
   if (nccl_comm_) {
     dh::safe_nccl(ncclCommDestroy(nccl_comm_));
   }
@@ -123,8 +119,8 @@ ncclRedOp_t GetNcclRedOp(Operation const &op) {
 template <typename Func>
 void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
-                         std::size_t size, cudaStream_t stream) {
-  dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
+                         std::size_t size) {
+  dh::LaunchN(size, [=] __device__(std::size_t idx) {
     auto result = device_buffer[idx];
     for (auto rank = 1; rank < world_size; rank++) {
       result = func(result, device_buffer[rank * size + idx]);
@@ -142,25 +138,22 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
   // First gather data from all the workers.
   dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
-                              nccl_comm_, cuda_stream_));
+                              nccl_comm_, dh::DefaultStream()));
   if (needs_sync_) {
-    dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+    dh::DefaultStream().Sync();
   }
 
   // Then reduce locally.
   auto *out_buffer = static_cast<char *>(send_receive_buffer);
   switch (op) {
     case Operation::kBitwiseAND:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
-                          cuda_stream_);
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
       break;
     case Operation::kBitwiseOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
-                          cuda_stream_);
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
       break;
     case Operation::kBitwiseXOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
-                          cuda_stream_);
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
       break;
     default:
       LOG(FATAL) << "Not a bitwise reduce operation.";
@@ -179,7 +172,7 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
   } else {
     dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
                                 GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
-                                cuda_stream_));
+                                dh::DefaultStream()));
   }
   allreduce_bytes_ += count * GetTypeSize(data_type);
   allreduce_calls_ += 1;
@@ -206,7 +199,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
   for (int32_t i = 0; i < world_size_; ++i) {
     size_t as_bytes = segments->at(i);
     dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
-                                ncclChar, i, nccl_comm_, cuda_stream_));
+                                ncclChar, i, nccl_comm_, dh::DefaultStream()));
     offset += as_bytes;
   }
   dh::safe_nccl(ncclGroupEnd());
@@ -217,7 +210,7 @@ void NcclDeviceCommunicator::Synchronize() {
     return;
   }
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
-  dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+  dh::DefaultStream().Sync();
 }
 
 }  // namespace collective
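
With the communicator-owned cuda_stream_ gone, every NCCL collective above is enqueued on the calling thread's default stream and synchronized via dh::DefaultStream().Sync(). A condensed sketch of the resulting pattern, assuming device_helpers.cuh and nccl.h are available and an ncclComm_t is already initialized; the free function and its names are illustrative, not part of the commit:

    void AllReduceSumSketch(ncclComm_t comm, double* device_buffer, std::size_t count) {
      // The collective runs on the default stream: per-thread when the build flag is
      // on, the legacy stream otherwise.
      dh::safe_nccl(ncclAllReduce(device_buffer, device_buffer, count, ncclDouble, ncclSum,
                                  comm, dh::DefaultStream()));
      // Blocking the caller is now a sync on that stream instead of a private one.
      dh::DefaultStream().Sync();
    }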


@@ -77,7 +77,6 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
   int const world_size_;
   int const rank_;
   ncclComm_t nccl_comm_{};
-  cudaStream_t cuda_stream_{};
   ncclUniqueId nccl_unique_id_{};
   size_t allreduce_bytes_{0};  // Keep statistics of the number of bytes communicated.
   size_t allreduce_calls_{0};  // Keep statistics of the number of reduce calls.


@@ -480,7 +480,7 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
   cub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
     // Configure allocator with maximum cached bin size of ~1GB and no limit on
     // maximum cached bytes
-    static cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
+    thread_local cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
     return *allocator;
   }
   pointer allocate(size_t n) {  // NOLINT
@@ -1176,7 +1176,13 @@ inline void CUDAEvent::Record(CUDAStreamView stream) {  // NOLINT
   dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
 }
 
-inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamLegacy}; }
+inline CUDAStreamView DefaultStream() {
+#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
+  return CUDAStreamView{cudaStreamPerThread};
+#else
+  return CUDAStreamView{cudaStreamLegacy};
+#endif
+}
 
 class CUDAStream {
   cudaStream_t stream_;
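
Call sites do not change: they keep asking dh::DefaultStream() for a view and use it as before; only the handle behind the view differs between build modes. A small usage sketch (the kernel and wrapper function are made up for illustration; CUDAStreamView, Sync(), and the explicit cudaStream_t conversion follow the diff above):

    __global__ void ScaleKernel(float* x, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] *= 2.0f;
    }

    void ScaleOnDefaultStream(float* device_x, int n) {
      auto stream = dh::DefaultStream();  // cudaStreamPerThread or cudaStreamLegacy
      ScaleKernel<<<(n + 255) / 256, 256, 0, cudaStream_t{stream}>>>(device_x, n);
      stream.Sync();  // same call in either build mode
    }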


@@ -134,12 +134,12 @@ void LaunchGetColumnSizeKernel(std::int32_t device, IterSpan<BatchIt> batch_iter
      CHECK(!force_use_u64);
      auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
      auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
-      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
          kernel, batch_iter, is_valid, out_column_size);
    } else {
      auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
      auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
-      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
          kernel, batch_iter, is_valid, out_column_size);
    }
  } else {


@@ -18,12 +18,10 @@ RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
   dh::safe_cuda(cudaSetDevice(device_idx_));
   ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
   thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
-  dh::safe_cuda(cudaStreamCreate(&stream_));
 }
 
 RowPartitioner::~RowPartitioner() {
   dh::safe_cuda(cudaSetDevice(device_idx_));
-  dh::safe_cuda(cudaStreamDestroy(stream_));
 }
 
 common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {


@@ -116,7 +116,7 @@ template <typename RowIndexT, typename OpT, typename OpDataT>
 void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
                        common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
                        common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
-                       dh::device_vector<int8_t>* tmp, cudaStream_t stream) {
+                       dh::device_vector<int8_t>* tmp) {
   dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
   WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
                                              d_counts.data()};
@@ -135,12 +135,12 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
   size_t temp_bytes = 0;
   if (tmp->empty()) {
     cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
-                                   IndexFlagOp(), total_rows, stream);
+                                   IndexFlagOp(), total_rows);
     tmp->resize(temp_bytes);
   }
   temp_bytes = tmp->size();
   cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
-                                 discard_write_iterator, IndexFlagOp(), total_rows, stream);
+                                 discard_write_iterator, IndexFlagOp(), total_rows);
 
   constexpr int kBlockSize = 256;
@@ -149,7 +149,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
   const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);
   SortPositionCopyKernel<kBlockSize, RowIndexT, OpDataT>
-      <<<grid_size, kBlockSize, 0, stream>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
+      <<<grid_size, kBlockSize, 0>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
 }
 
 struct NodePositionInfo {
@@ -221,7 +221,6 @@ class RowPartitioner {
   dh::device_vector<int8_t> tmp_;
   dh::PinnedMemory pinned_;
   dh::PinnedMemory pinned2_;
-  cudaStream_t stream_;
 
  public:
   RowPartitioner(int device_idx, size_t num_rows);
@@ -278,7 +277,7 @@ class RowPartitioner {
     }
     dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
                                   h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
-                                  cudaMemcpyDefault, stream_));
+                                  cudaMemcpyDefault));
 
     // Temporary arrays
     auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
@@ -287,12 +286,12 @@ class RowPartitioner {
     // Partition the rows according to the operator
     SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
         dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
-        total_rows, op, &tmp_, stream_);
+        total_rows, op, &tmp_);
     dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
-                                  cudaMemcpyDefault, stream_));
+                                  cudaMemcpyDefault));
     // TODO(Rory): this synchronisation hurts performance a lot
     // Future optimisation should find a way to skip this
-    dh::safe_cuda(cudaStreamSynchronize(stream_));
+    dh::DefaultStream().Sync();
 
     // Update segments
     for (size_t i = 0; i < nidx.size(); i++) {
@@ -327,13 +326,13 @@ class RowPartitioner {
     dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
     dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
                                   sizeof(NodePositionInfo) * ridx_segments_.size(),
-                                  cudaMemcpyDefault, stream_));
+                                  cudaMemcpyDefault));
 
     constexpr int kBlockSize = 512;
     const int kItemsThread = 8;
     const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
     common::Span<const RowIndexT> d_ridx(ridx_.data().get(), ridx_.size());
-    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0, stream_>>>(
+    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0>>>(
         dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
   }
 };
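
The partitioner now relies on default-stream ordering: the CUB scan, the copy kernel, the cudaMemcpyAsync calls, and the final sync above are all issued by the same host thread with no stream argument, so they serialize on that thread's default stream (per-thread under the new flag). A condensed restatement of that ordering, using the names from the diff:

    // Kernel, async copy, and sync all target the same (default) stream, so the
    // host only reads h_counts after the kernel and the copy have completed.
    SortPositionCopyKernel<kBlockSize, RowIndexT, OpDataT>
        <<<grid_size, kBlockSize, 0>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
    dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
                                  cudaMemcpyDefault));
    dh::DefaultStream().Sync();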


@@ -73,7 +73,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
   dh::device_vector<int8_t> tmp;
   SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
                                                  dh::ToSpan(ridx_tmp), dh::ToSpan(counts),
-                                                 total_rows, op, &tmp, nullptr);
+                                                 total_rows, op, &tmp);
 
   auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; };
   for (size_t i = 0; i < segments.size(); i++) {