Support CUDA f16 without transformation. (#9207)

- Support f16 from cupy.
- Include CUDA header explicitly.
- Clean up CMake NVTX support.
This commit is contained in:
Jiaming Yuan
2023-05-30 20:54:31 +08:00
committed by GitHub
parent 6f83d9c69a
commit 097f11b6e0
5 changed files with 27 additions and 54 deletions

View File

@@ -124,13 +124,6 @@ function(format_gencode_flags flags out)
endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")
endfunction(format_gencode_flags flags)
# Enable NVTX for `target`.
#
# Locates the NVTX package (configuration fails if it is missing), then
# privately attaches the NVTX include directory, the NVTX library and the
# XGBOOST_USE_NVTX=1 compile definition to `target`.
#
# This is a macro, not a function, so it introduces no new variable scope:
# whatever find_package() sets is visible to the caller.
#
# Example: enable_nvtx(my_target)
macro(enable_nvtx target)
  find_package(NVTX REQUIRED)

  # A leading -D on a definition is stripped by CMake, so the bare
  # NAME=VALUE form yields the same definition.
  target_compile_definitions(${target} PRIVATE XGBOOST_USE_NVTX=1)
  target_link_libraries(${target} PRIVATE "${NVTX_LIBRARY}")
  target_include_directories(${target} PRIVATE "${NVTX_INCLUDE_DIR}")
endmacro()
# Set CUDA-related flags on the target. Must be called after `format_gencode_flags`.
function(xgboost_set_cuda_flags target)
target_compile_options(${target} PRIVATE
@@ -162,11 +155,14 @@ function(xgboost_set_cuda_flags target)
endif (USE_DEVICE_DEBUG)
if (USE_NVTX)
enable_nvtx(${target})
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_NVTX=1)
endif (USE_NVTX)
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1)
target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/gputreeshap)
target_include_directories(
${target} PRIVATE
${xgboost_SOURCE_DIR}/gputreeshap
${CUDAToolkit_INCLUDE_DIRS})
if (MSVC)
target_compile_options(${target} PRIVATE
@@ -289,7 +285,7 @@ macro(xgboost_target_link_libraries target)
endif (USE_NCCL)
if (USE_NVTX)
enable_nvtx(${target})
target_link_libraries(${target} PRIVATE CUDA::nvToolsExt)
endif (USE_NVTX)
if (RABIT_BUILD_MPI)

View File

@@ -1,26 +0,0 @@
# Find module for NVTX (the nvToolsExt header and library).
#
# Variables produced:
#   NVTX_INCLUDE_DIR - directory containing nvToolsExt.h
#   NVTX_LIBRARY     - path to the nvToolsExt library
# find_package_handle_standard_args() reports success only when both of the
# above were found.

include(FindPackageHandleStandardArgs)

# Discard a previously cached library path so that find_library() below
# performs a fresh search.
if (NVTX_LIBRARY)
  unset(NVTX_LIBRARY CACHE)
endif ()

set(NVTX_LIB_NAME nvToolsExt)

# Header lookup: CUDA_HOME / CUDA_INCLUDE are tried as hints, followed by the
# conventional /usr/local/cuda location.
find_path(
  NVTX_INCLUDE_DIR
  NAMES nvToolsExt.h
  PATHS
    ${CUDA_HOME}/include
    ${CUDA_INCLUDE}
    /usr/local/cuda/include)

# Library lookup, using the same kind of hint locations.
find_library(
  NVTX_LIBRARY
  NAMES ${NVTX_LIB_NAME}
  PATHS
    ${CUDA_HOME}/lib64
    /usr/local/cuda/lib64)

# Printed before the found/not-found check below, so this may show a
# NOTFOUND value.
message(STATUS "Using nvtx library: ${NVTX_LIBRARY}")

find_package_handle_standard_args(
  NVTX
  DEFAULT_MSG
  NVTX_INCLUDE_DIR
  NVTX_LIBRARY)

mark_as_advanced(NVTX_INCLUDE_DIR NVTX_LIBRARY)