Work around nvcc error. (#7673)

This commit is contained in:
Jiaming Yuan 2022-02-19 01:41:46 +08:00 committed by GitHub
parent 3877043d41
commit d625dc2047
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 17 deletions

View File

@ -81,25 +81,26 @@ namespace common {
#if defined(_MSC_VER)
// Windows CUDA doesn't have __assert_fail, so a failed check executes asm("trap;") instead.
#define KERNEL_CHECK(cond) \
do { \
if (XGBOOST_EXPECT(!(cond), false)) { \
asm("trap;"); \
} \
#define CUDA_KERNEL_CHECK(cond) \
do { \
if (XGBOOST_EXPECT(!(cond), false)) { \
asm("trap;"); \
} \
} while (0)
#else // defined(_MSC_VER)
// Turns its argument into a string literal via the preprocessor `#` operator.
// The assertion macros below use it to build the message passed to __assert_fail.
#define __ASSERT_STR_HELPER(x) #x
// Expression-form assertion for the non-MSVC branch: yields static_cast<void>(0)
// when `cond` holds, otherwise calls __assert_fail with the stringified
// condition, __FILE__, __LINE__ and __PRETTY_FUNCTION__.
// NOTE(review): XGBOOST_EXPECT is defined elsewhere; presumably a
// branch-prediction hint whose second argument is the expected outcome - confirm.
#define KERNEL_CHECK(cond) \
(XGBOOST_EXPECT((cond), true) \
? static_cast<void>(0) \
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, \
__PRETTY_FUNCTION__))
// Expression-form assertion for the non-MSVC branch: yields static_cast<void>(0)
// when `cond` holds, otherwise calls __assert_fail with the stringified
// condition, __FILE__, __LINE__ and __PRETTY_FUNCTION__.
// NOTE(review): XGBOOST_EXPECT is defined elsewhere; presumably a
// branch-prediction hint whose second argument is the expected outcome - confirm.
#define CUDA_KERNEL_CHECK(cond) \
(XGBOOST_EXPECT((cond), true) \
? static_cast<void>(0) \
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, __PRETTY_FUNCTION__))
#endif // defined(_MSC_VER)
// Object-like aliases: KERNEL_CHECK expands to CUDA_KERNEL_CHECK, and SPAN_CHECK
// expands to KERNEL_CHECK, so all three names resolve to the same check here.
#define KERNEL_CHECK CUDA_KERNEL_CHECK
#define SPAN_CHECK KERNEL_CHECK
#else // ------------------------------ not CUDA ----------------------------
@ -120,11 +121,7 @@ namespace common {
#endif // __CUDA_ARCH__
// Asserts `lhs < rhs` through KERNEL_CHECK. Both branches of this conditional
// expand to the same text, so the __CUDA_ARCH__ test does not change the result.
#if defined(__CUDA_ARCH__)
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
#else
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
#endif // defined(__CUDA_ARCH__)
// Asserts `lhs < rhs` through SPAN_CHECK. Each operand is parenthesized so that
// expression arguments keep their own precedence inside the comparison.
#define SPAN_LT(lhs, rhs) SPAN_CHECK((lhs) < (rhs))
namespace detail {
/*!
@ -671,7 +668,6 @@ XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLIN
Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
}
} // namespace common
} // namespace xgboost

View File

@ -120,7 +120,8 @@ class RowPartitioner {
int64_t* d_left_count = left_counts_.data().get() + nidx;
// Launch 1 thread for each row
dh::LaunchN<1, 128>(segment.Size(), [=] __device__(size_t idx) {
dh::LaunchN<1, 128>(segment.Size(), [segment, op, left_nidx, right_nidx, d_ridx, d_left_count,
d_position] __device__(size_t idx) {
// LaunchN starts from zero, so we restore the row index by adding segment.begin
idx += segment.begin;
RowIndexT ridx = d_ridx[idx];