Work around nvcc error. (#7673)

This commit is contained in:
Jiaming Yuan 2022-02-19 01:41:46 +08:00 committed by GitHub
parent 3877043d41
commit d625dc2047
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 17 deletions

View File

@ -81,7 +81,7 @@ namespace common {
#if defined(_MSC_VER) #if defined(_MSC_VER)
// Windows CUDA doesn't have __assert_fail. // Windows CUDA doesn't have __assert_fail.
#define KERNEL_CHECK(cond) \ #define CUDA_KERNEL_CHECK(cond) \
do { \ do { \
if (XGBOOST_EXPECT(!(cond), false)) { \ if (XGBOOST_EXPECT(!(cond), false)) { \
asm("trap;"); \ asm("trap;"); \
@ -92,14 +92,15 @@ namespace common {
#define __ASSERT_STR_HELPER(x) #x #define __ASSERT_STR_HELPER(x) #x
#define KERNEL_CHECK(cond) \ #define CUDA_KERNEL_CHECK(cond) \
(XGBOOST_EXPECT((cond), true) \ (XGBOOST_EXPECT((cond), true) \
? static_cast<void>(0) \ ? static_cast<void>(0) \
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, \ : __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, __PRETTY_FUNCTION__))
__PRETTY_FUNCTION__))
#endif // defined(_MSC_VER) #endif // defined(_MSC_VER)
#define KERNEL_CHECK CUDA_KERNEL_CHECK
#define SPAN_CHECK KERNEL_CHECK #define SPAN_CHECK KERNEL_CHECK
#else // ------------------------------ not CUDA ---------------------------- #else // ------------------------------ not CUDA ----------------------------
@ -120,11 +121,7 @@ namespace common {
#endif // __CUDA_ARCH__ #endif // __CUDA_ARCH__
#if defined(__CUDA_ARCH__) #define SPAN_LT(lhs, rhs) SPAN_CHECK((lhs) < (rhs))
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
#else
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
#endif // defined(__CUDA_ARCH__)
namespace detail { namespace detail {
/*! /*!
@ -671,7 +668,6 @@ XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLIN
Span<byte, detail::ExtentAsBytesValue<T, E>::value> { Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()}; return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
} }
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -120,7 +120,8 @@ class RowPartitioner {
int64_t* d_left_count = left_counts_.data().get() + nidx; int64_t* d_left_count = left_counts_.data().get() + nidx;
// Launch 1 thread for each row // Launch 1 thread for each row
dh::LaunchN<1, 128>(segment.Size(), [=] __device__(size_t idx) { dh::LaunchN<1, 128>(segment.Size(), [segment, op, left_nidx, right_nidx, d_ridx, d_left_count,
d_position] __device__(size_t idx) {
// LaunchN starts from zero, so we restore the row index by adding segment.begin // LaunchN starts from zero, so we restore the row index by adding segment.begin
idx += segment.begin; idx += segment.begin;
RowIndexT ridx = d_ridx[idx]; RowIndexT ridx = d_ridx[idx];