Support CUDA f16 without transformation. (#9207)
- Support f16 from cupy. - Include CUDA header explicitly. - Cleanup cmake nvtx support.
This commit is contained in:
@@ -124,13 +124,6 @@ function(format_gencode_flags flags out)
|
||||
endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")
|
||||
endfunction(format_gencode_flags flags)
|
||||
|
||||
macro(enable_nvtx target)
|
||||
find_package(NVTX REQUIRED)
|
||||
target_include_directories(${target} PRIVATE "${NVTX_INCLUDE_DIR}")
|
||||
target_link_libraries(${target} PRIVATE "${NVTX_LIBRARY}")
|
||||
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_NVTX=1)
|
||||
endmacro()
|
||||
|
||||
# Set CUDA related flags to target. Must be used after code `format_gencode_flags`.
|
||||
function(xgboost_set_cuda_flags target)
|
||||
target_compile_options(${target} PRIVATE
|
||||
@@ -162,11 +155,14 @@ function(xgboost_set_cuda_flags target)
|
||||
endif (USE_DEVICE_DEBUG)
|
||||
|
||||
if (USE_NVTX)
|
||||
enable_nvtx(${target})
|
||||
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_NVTX=1)
|
||||
endif (USE_NVTX)
|
||||
|
||||
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1)
|
||||
target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/gputreeshap)
|
||||
target_include_directories(
|
||||
${target} PRIVATE
|
||||
${xgboost_SOURCE_DIR}/gputreeshap
|
||||
${CUDAToolkit_INCLUDE_DIRS})
|
||||
|
||||
if (MSVC)
|
||||
target_compile_options(${target} PRIVATE
|
||||
@@ -289,7 +285,7 @@ macro(xgboost_target_link_libraries target)
|
||||
endif (USE_NCCL)
|
||||
|
||||
if (USE_NVTX)
|
||||
enable_nvtx(${target})
|
||||
target_link_libraries(${target} PRIVATE CUDA::nvToolsExt)
|
||||
endif (USE_NVTX)
|
||||
|
||||
if (RABIT_BUILD_MPI)
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
if (NVTX_LIBRARY)
|
||||
unset(NVTX_LIBRARY CACHE)
|
||||
endif (NVTX_LIBRARY)
|
||||
|
||||
set(NVTX_LIB_NAME nvToolsExt)
|
||||
|
||||
|
||||
find_path(NVTX_INCLUDE_DIR
|
||||
NAMES nvToolsExt.h
|
||||
PATHS ${CUDA_HOME}/include ${CUDA_INCLUDE} /usr/local/cuda/include)
|
||||
|
||||
|
||||
find_library(NVTX_LIBRARY
|
||||
NAMES nvToolsExt
|
||||
PATHS ${CUDA_HOME}/lib64 /usr/local/cuda/lib64)
|
||||
|
||||
message(STATUS "Using nvtx library: ${NVTX_LIBRARY}")
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(NVTX DEFAULT_MSG
|
||||
NVTX_INCLUDE_DIR NVTX_LIBRARY)
|
||||
|
||||
mark_as_advanced(
|
||||
NVTX_INCLUDE_DIR
|
||||
NVTX_LIBRARY
|
||||
)
|
||||
Reference in New Issue
Block a user