Work around nvcc error. (#7673)
This commit is contained in:
parent
3877043d41
commit
d625dc2047
@ -81,25 +81,26 @@ namespace common {
|
|||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
|
|
||||||
// Windows CUDA doesn't have __assert_fail.
|
// Windows CUDA doesn't have __assert_fail.
|
||||||
#define KERNEL_CHECK(cond) \
|
#define CUDA_KERNEL_CHECK(cond) \
|
||||||
do { \
|
do { \
|
||||||
if (XGBOOST_EXPECT(!(cond), false)) { \
|
if (XGBOOST_EXPECT(!(cond), false)) { \
|
||||||
asm("trap;"); \
|
asm("trap;"); \
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
#else // defined(_MSC_VER)
|
#else // defined(_MSC_VER)
|
||||||
|
|
||||||
#define __ASSERT_STR_HELPER(x) #x
|
#define __ASSERT_STR_HELPER(x) #x
|
||||||
|
|
||||||
#define KERNEL_CHECK(cond) \
|
#define CUDA_KERNEL_CHECK(cond) \
|
||||||
(XGBOOST_EXPECT((cond), true) \
|
(XGBOOST_EXPECT((cond), true) \
|
||||||
? static_cast<void>(0) \
|
? static_cast<void>(0) \
|
||||||
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, \
|
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, __PRETTY_FUNCTION__))
|
||||||
__PRETTY_FUNCTION__))
|
|
||||||
|
|
||||||
#endif // defined(_MSC_VER)
|
#endif // defined(_MSC_VER)
|
||||||
|
|
||||||
|
#define KERNEL_CHECK CUDA_KERNEL_CHECK
|
||||||
|
|
||||||
#define SPAN_CHECK KERNEL_CHECK
|
#define SPAN_CHECK KERNEL_CHECK
|
||||||
|
|
||||||
#else // ------------------------------ not CUDA ----------------------------
|
#else // ------------------------------ not CUDA ----------------------------
|
||||||
@ -120,11 +121,7 @@ namespace common {
|
|||||||
|
|
||||||
#endif // __CUDA_ARCH__
|
#endif // __CUDA_ARCH__
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__)
|
#define SPAN_LT(lhs, rhs) SPAN_CHECK((lhs) < (rhs))
|
||||||
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
|
|
||||||
#else
|
|
||||||
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
|
|
||||||
#endif // defined(__CUDA_ARCH__)
|
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
/*!
|
/*!
|
||||||
@ -671,7 +668,6 @@ XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLIN
|
|||||||
Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
|
Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
|
||||||
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
|
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
@ -120,7 +120,8 @@ class RowPartitioner {
|
|||||||
|
|
||||||
int64_t* d_left_count = left_counts_.data().get() + nidx;
|
int64_t* d_left_count = left_counts_.data().get() + nidx;
|
||||||
// Launch 1 thread for each row
|
// Launch 1 thread for each row
|
||||||
dh::LaunchN<1, 128>(segment.Size(), [=] __device__(size_t idx) {
|
dh::LaunchN<1, 128>(segment.Size(), [segment, op, left_nidx, right_nidx, d_ridx, d_left_count,
|
||||||
|
d_position] __device__(size_t idx) {
|
||||||
// LaunchN starts from zero, so we restore the row index by adding segment.begin
|
// LaunchN starts from zero, so we restore the row index by adding segment.begin
|
||||||
idx += segment.begin;
|
idx += segment.begin;
|
||||||
RowIndexT ridx = d_ridx[idx];
|
RowIndexT ridx = d_ridx[idx];
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user