Work around nvcc error. (#7673)
This commit is contained in:
parent
3877043d41
commit
d625dc2047
@ -81,7 +81,7 @@ namespace common {
|
||||
#if defined(_MSC_VER)
|
||||
|
||||
// Windows CUDA doesn't have __assert_fail.
|
||||
#define KERNEL_CHECK(cond) \
|
||||
#define CUDA_KERNEL_CHECK(cond) \
|
||||
do { \
|
||||
if (XGBOOST_EXPECT(!(cond), false)) { \
|
||||
asm("trap;"); \
|
||||
@ -92,14 +92,15 @@ namespace common {
|
||||
|
||||
#define __ASSERT_STR_HELPER(x) #x
|
||||
|
||||
#define KERNEL_CHECK(cond) \
|
||||
#define CUDA_KERNEL_CHECK(cond) \
|
||||
(XGBOOST_EXPECT((cond), true) \
|
||||
? static_cast<void>(0) \
|
||||
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, \
|
||||
__PRETTY_FUNCTION__))
|
||||
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, __PRETTY_FUNCTION__))
|
||||
|
||||
#endif // defined(_MSC_VER)
|
||||
|
||||
#define KERNEL_CHECK CUDA_KERNEL_CHECK
|
||||
|
||||
#define SPAN_CHECK KERNEL_CHECK
|
||||
|
||||
#else // ------------------------------ not CUDA ----------------------------
|
||||
@ -120,11 +121,7 @@ namespace common {
|
||||
|
||||
#endif // __CUDA_ARCH__
|
||||
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
|
||||
#else
|
||||
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
|
||||
#endif // defined(__CUDA_ARCH__)
|
||||
#define SPAN_LT(lhs, rhs) SPAN_CHECK((lhs) < (rhs))
|
||||
|
||||
namespace detail {
|
||||
/*!
|
||||
@ -671,7 +668,6 @@ XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLIN
|
||||
Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
|
||||
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@ -120,7 +120,8 @@ class RowPartitioner {
|
||||
|
||||
int64_t* d_left_count = left_counts_.data().get() + nidx;
|
||||
// Launch 1 thread for each row
|
||||
dh::LaunchN<1, 128>(segment.Size(), [=] __device__(size_t idx) {
|
||||
dh::LaunchN<1, 128>(segment.Size(), [segment, op, left_nidx, right_nidx, d_ridx, d_left_count,
|
||||
d_position] __device__(size_t idx) {
|
||||
// LaunchN starts from zero, so we restore the row index by adding segment.begin
|
||||
idx += segment.begin;
|
||||
RowIndexT ridx = d_ridx[idx];
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user