Fix compiler errors

Hui Liu 2024-01-12 12:09:01 -08:00
parent 1e1e8be3a5
commit 9759e28e6a
10 changed files with 90 additions and 82 deletions

View File

@@ -17,6 +17,8 @@
#include "xgboost/learner.h"
#if defined(XGBOOST_USE_NCCL)
#include <nccl.h>
#elif defined(XGBOOST_USE_RCCL)
#include <rccl.h>
#endif
namespace xgboost {
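Since RCCL implements the NCCL API under the same nccl* names and types, switching only the included header is enough for the collective calls below these guards to build on both backends. A minimal standalone sketch of that pattern (hypothetical program, not part of this commit):

// The same call site compiles against NCCL or RCCL; only the header differs,
// because RCCL exposes the NCCL symbol names.
#if defined(XGBOOST_USE_NCCL)
#include <nccl.h>
#elif defined(XGBOOST_USE_RCCL)
#include <rccl.h>
#endif
#include <iostream>

int main() {
  int version = 0;
  if (ncclGetVersion(&version) == ncclSuccess) {
    std::cout << "collective library version: " << version << "\n";
  }
  return 0;
}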

View File

@@ -1,15 +1,25 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL) || (defined(XGBOOST_USE_RCCL) && 0)
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include "nccl_stub.h"
#if defined(XGBOOST_USE_NCCL)
#include <cuda.h> // for CUDA_VERSION
#include <cuda_runtime_api.h> // for cudaPeekAtLastError
#include <dlfcn.h> // for dlclose, dlsym, dlopen
#include <nccl.h>
#include <thrust/system/cuda/error.h> // for cuda_category
#include <thrust/system_error.h> // for system_error
#elif defined(XGBOOST_USE_RCCL)
#include "../common/cuda_to_hip.h"
#include "../common/device_helpers.hip.h"
#include <hip/hip_runtime_api.h> // for cudaPeekAtLastError
#include <dlfcn.h> // for dlclose, dlsym, dlopen
#include <rccl.h>
#include <thrust/system/hip/error.h> // for cuda_category
#include <thrust/system_error.h> // for system_error
#endif
#include <cstdint> // for int32_t
#include <sstream> // for stringstream
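The <dlfcn.h> include above reflects how the stub resolves the collective library at runtime with dlopen/dlsym rather than linking against it directly. As an illustration only (library path, type alias, and error handling here are assumptions, not the actual nccl_stub code):

// Hypothetical sketch of runtime symbol resolution via dlopen/dlsym.
#include <dlfcn.h>
#include <iostream>

using NcclGetVersionFn = int (*)(int *);

int main() {
  void *handle = dlopen("libnccl.so", RTLD_LAZY);  // librccl.so on a ROCm build
  if (handle == nullptr) {
    std::cerr << "failed to load collective library: " << dlerror() << "\n";
    return 1;
  }
  auto get_version = reinterpret_cast<NcclGetVersionFn>(dlsym(handle, "ncclGetVersion"));
  if (get_version != nullptr) {
    int version = 0;
    get_version(&version);
    std::cout << "loaded collective library, version " << version << "\n";
  }
  dlclose(handle);
  return 0;
}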

View File

@@ -2,10 +2,17 @@
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#if defined(XGBOOST_USE_NCCL) || (defined(XGBOOST_USE_RCCL) && 0)
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#if defined(XGBOOST_USE_NCCL)
#include <cuda_runtime_api.h>
#include <nccl.h>
#elif defined(XGBOOST_USE_RCCL)
#include "../common/cuda_to_hip.h"
#include "../common/device_helpers.cuh"
#include <hip/hip_runtime_api.h>
#include <rccl.h>
#endif
#include <string> // for string
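The ../common/cuda_to_hip.h header pulled in by the RCCL branch lets CUDA-named calls in the shared source resolve to their HIP equivalents. Its contents are not part of this diff; the usual shape of such a mapping header is sketched below (every alias here is an assumption, not the actual file):

// Hypothetical CUDA-to-HIP alias header; the real cuda_to_hip.h may differ
// in coverage and spelling.
#pragma once
#include <hip/hip_runtime_api.h>

#define cudaError_t              hipError_t
#define cudaSuccess              hipSuccess
#define cudaStream_t             hipStream_t
#define cudaMemcpyAsync          hipMemcpyAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaPeekAtLastError      hipPeekAtLastError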

View File

@@ -226,6 +226,7 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
});
}
#if defined(XGBOOST_USE_CUDA)
template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
xgboost::common::Span<IdxT> sorted_idx) {
@@ -295,5 +296,51 @@ void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
cuctx->Stream()));
}
#elif defined(XGBOOST_USE_HIP)
template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
xgboost::common::Span<IdxT> sorted_idx) {
std::size_t bytes = 0;
auto cuctx = ctx->CUDACtx();
dh::Iota(sorted_idx, cuctx->Stream());
using KeyT = typename decltype(keys)::value_type;
using ValueT = std::remove_const_t<IdxT>;
dh::TemporaryArray<KeyT> out(keys.size());
dh::TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
// track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
using OffsetT = std::conditional_t<!dh::BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
if (accending) {
void *d_temp_storage = nullptr;
dh::safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8, cuctx->Stream(), false)));
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
dh::safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8, cuctx->Stream(), false)));
} else {
void *d_temp_storage = nullptr;
dh::safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8, cuctx->Stream(), false)));
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
dh::safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8, cuctx->Stream(), false)));
}
dh::safe_cuda(hipMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), hipMemcpyDeviceToDevice, cuctx->Stream()));
}
#endif
} // namespace xgboost::common
#endif // XGBOOST_COMMON_ALGORITHM_CUH_
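The HIP branch above follows rocprim's two-pass convention: the first radix_sort_pairs call receives a null temporary-storage pointer and only writes the required byte count, and the second call with the allocated scratch buffer performs the sort. A self-contained sketch of the same pattern outside XGBoost (error handling elided):

// Standalone illustration of rocprim::radix_sort_pairs' size-query-then-sort pattern.
#include <hip/hip_runtime.h>
#include <rocprim/rocprim.hpp>
#include <cstddef>

void SortPairs(float *keys_in, float *keys_out, int *vals_in, int *vals_out,
               std::size_t n, hipStream_t stream) {
  std::size_t bytes = 0;
  // First pass: temporary storage is nullptr, so only `bytes` is computed.
  rocprim::radix_sort_pairs(nullptr, bytes, keys_in, keys_out, vals_in, vals_out,
                            n, 0, sizeof(float) * 8, stream, false);
  void *d_temp_storage = nullptr;
  hipMalloc(&d_temp_storage, bytes);
  // Second pass: identical arguments plus real scratch space; performs the sort.
  rocprim::radix_sort_pairs(d_temp_storage, bytes, keys_in, keys_out, vals_in, vals_out,
                            n, 0, sizeof(float) * 8, stream, false);
  hipFree(d_temp_storage);
}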

View File

@@ -40,10 +40,6 @@
#include "xgboost/logging.h"
#include "xgboost/span.h"
#ifdef XGBOOST_USE_RCCL
#include "rccl.h"
#endif // XGBOOST_USE_RCCL
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
#include "rmm/mr/device/per_device_resource.hpp"
#include "rmm/mr/device/thrust_allocator_adaptor.hpp"
@@ -98,30 +94,6 @@ XGBOOST_DEV_INLINE T atomicAdd(T *addr, T v) { // NOLINT
}
namespace dh {
#ifdef XGBOOST_USE_RCCL
#define safe_nccl(ans) ThrowOnNcclError((ans), __FILE__, __LINE__)
inline ncclResult_t ThrowOnNcclError(ncclResult_t code, const char *file, int line) {
if (code != ncclSuccess) {
std::stringstream ss;
ss << "RCCL failure: " << ncclGetErrorString(code) << ".";
ss << " " << file << "(" << line << ")\n";
if (code == ncclUnhandledCudaError) {
// nccl usually preserves the last error so we can get more details.
auto err = hipPeekAtLastError();
ss << " CUDA error: " << thrust::system_error(err, thrust::hip_category()).what() << "\n";
} else if (code == ncclSystemError) {
ss << " This might be caused by a network configuration issue. Please consider specifying "
"the network interface for RCCL via environment variables listed in its reference: "
"`https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html`.\n";
}
LOG(FATAL) << ss.str();
}
return code;
}
#endif
inline int32_t CudaGetPointerDevice(void const *ptr) {
int32_t device = -1;
hipPointerAttribute_t attr;
@@ -298,8 +270,8 @@ inline void LaunchN(size_t n, L lambda) {
}
template <typename Container>
void Iota(Container array) {
LaunchN(array.size(), [=] __device__(size_t i) { array[i] = i; });
void Iota(Container array, cudaStream_t stream) {
LaunchN(array.size(), stream, [=] __device__(size_t i) { array[i] = i; });
}
namespace detail {
@@ -465,7 +437,8 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
hipcub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
// Configure allocator with maximum cached bin size of ~1GB and no limit on
// maximum cached bytes
static hipcub::CachingDeviceAllocator *allocator = new hipcub::CachingDeviceAllocator(2, 9, 29);
thread_local std::unique_ptr<hipcub::CachingDeviceAllocator> allocator{
std::make_unique<hipcub::CachingDeviceAllocator>(2, 9, 29)};
return *allocator;
}
pointer allocate(size_t n) { // NOLINT
@@ -581,6 +554,16 @@ class DoubleBuffer {
T *Other() { return buff.Alternate(); }
};
template <typename T>
xgboost::common::Span<T> LazyResize(xgboost::Context const *ctx,
xgboost::HostDeviceVector<T> *buffer, std::size_t n) {
buffer->SetDevice(ctx->Device());
if (buffer->Size() < n) {
buffer->Resize(n);
}
return buffer->DeviceSpan().subspan(0, n);
}
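A usage sketch for the LazyResize helper added here (call site and names are hypothetical, assuming an xgboost::Context *ctx is in scope): because the buffer only grows, repeated calls with a fluctuating size reuse the existing device allocation instead of reallocating each iteration.

xgboost::HostDeviceVector<float> cache;  // long-lived scratch buffer
for (std::size_t n : {256, 128, 512}) {
  auto span = dh::LazyResize(ctx, &cache, n);  // span.size() == n
  // ... launch kernels that write through span.data() ...
}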
/**
* \brief Copies device span to std::vector.
*
@@ -1017,49 +1000,6 @@ void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items)
InclusiveScan(d_in, d_out, hipcub::Sum(), num_items);
}
template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_idx) {
size_t bytes = 0;
Iota(sorted_idx);
using KeyT = typename decltype(keys)::value_type;
using ValueT = std::remove_const_t<IdxT>;
TemporaryArray<KeyT> out(keys.size());
TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
// track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
using OffsetT = std::conditional_t<!BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
if (accending) {
void *d_temp_storage = nullptr;
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
} else {
void *d_temp_storage = nullptr;
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
}
safe_cuda(hipMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), hipMemcpyDeviceToDevice));
}
class CUDAStreamView;
class CUDAEvent {
@@ -1105,6 +1045,8 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
dh::safe_cuda(hipEventRecord(event_, hipStream_t{stream}));
}
// Changing this has an effect on the prediction return, where we need to pass the pointer to
// third-party libraries like cuPy
inline CUDAStreamView DefaultStream() {
#ifdef HIP_API_PER_THREAD_DEFAULT_STREAM
return CUDAStreamView{hipStreamPerThread};
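For reference, a minimal standalone sketch (not XGBoost code) of what returning hipStreamPerThread gives: each host thread issues work on its own default stream, so concurrent threads are not serialized on the legacy null stream.

#include <hip/hip_runtime.h>
#include <vector>

int main() {
  std::vector<float> host(1024, 1.0f);
  float *device = nullptr;
  hipMalloc(&device, host.size() * sizeof(float));
  // hipStreamPerThread is the per-thread default stream DefaultStream() returns
  // when HIP_API_PER_THREAD_DEFAULT_STREAM is defined.
  hipMemcpyAsync(device, host.data(), host.size() * sizeof(float),
                 hipMemcpyHostToDevice, hipStreamPerThread);
  hipStreamSynchronize(hipStreamPerThread);
  hipFree(device);
  return 0;
}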

View File

@@ -6,7 +6,7 @@
#include "../common/common.h" // for AssertGPUSupport
namespace xgboost {
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
void ArrayInterfaceHandler::SyncCudaStream(int64_t) { common::AssertGPUSupport(); }
bool ArrayInterfaceHandler::IsCudaPtr(void const *) { return false; }
#endif // !defined(XGBOOST_USE_CUDA)

View File

@@ -9,7 +9,7 @@
#include <cstdint> // for int32_t
#include "../common/common.h" // for Range
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
#include "../common/linalg_op.cuh"
#endif
#include "../common/linalg_op.h"

View File

@@ -1,2 +1,2 @@
#include "test_multiclass_obj.cc"
#include "test_multiclass_obj_gpu.cu"

View File

@@ -193,7 +193,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) {
ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"tweedie-nloglik@1.1"});
}
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
TEST(Objective, CPU_vs_CUDA) {
Context ctx = MakeCUDACtx(GPUIDX);
@@ -271,7 +271,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) {
}
// CoxRegression not implemented in GPU code, no need for testing.
#if !defined(__CUDACC__)
#if !defined(__CUDACC__) && !defined(__HIP_PLATFORM_AMD__)
TEST(Objective, CoxRegressionGPair) {
Context ctx = MakeCUDACtx(GPUIDX);
std::vector<std::pair<std::string, std::string>> args;

View File

@@ -1,2 +1,2 @@
#include "test_regression_obj.cc"
#include "test_regression_obj_gpu.cu"