Work around nvcc error. (#7673)

This commit is contained in:
Jiaming Yuan
2022-02-19 01:41:46 +08:00
committed by GitHub
parent 3877043d41
commit d625dc2047
2 changed files with 14 additions and 17 deletions

View File

@@ -81,25 +81,26 @@ namespace common {
#if defined(_MSC_VER)
// Windows CUDA doesn't have __assert_fail.
#define KERNEL_CHECK(cond) \
do { \
if (XGBOOST_EXPECT(!(cond), false)) { \
asm("trap;"); \
} \
#define CUDA_KERNEL_CHECK(cond) \
do { \
if (XGBOOST_EXPECT(!(cond), false)) { \
asm("trap;"); \
} \
} while (0)
#else // defined(_MSC_VER)
#define __ASSERT_STR_HELPER(x) #x
#define KERNEL_CHECK(cond) \
(XGBOOST_EXPECT((cond), true) \
? static_cast<void>(0) \
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, \
__PRETTY_FUNCTION__))
#define CUDA_KERNEL_CHECK(cond) \
(XGBOOST_EXPECT((cond), true) \
? static_cast<void>(0) \
: __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, __PRETTY_FUNCTION__))
#endif // defined(_MSC_VER)
#define KERNEL_CHECK CUDA_KERNEL_CHECK
#define SPAN_CHECK KERNEL_CHECK
#else // ------------------------------ not CUDA ----------------------------
@@ -120,11 +121,7 @@ namespace common {
#endif // __CUDA_ARCH__
#if defined(__CUDA_ARCH__)
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
#else
#define SPAN_LT(lhs, rhs) KERNEL_CHECK((lhs) < (rhs))
#endif // defined(__CUDA_ARCH__)
#define SPAN_LT(lhs, rhs) SPAN_CHECK((lhs) < (rhs))
namespace detail {
/*!
@@ -671,7 +668,6 @@ XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLIN
Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
}
} // namespace common
} // namespace xgboost