Add cuda forwards compatibility (#3316)

This commit is contained in:
Rory Mitchell 2018-05-17 10:59:22 +12:00 committed by GitHub
parent f8b7686719
commit 3ee725e3bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 33 deletions

View File

@ -14,8 +14,8 @@ option(USE_NCCL "Build using NCCL for multi-GPU. Also requires USE_CUDA")
option(JVM_BINDINGS "Build JVM bindings" OFF)
option(GOOGLE_TEST "Build google tests" OFF)
option(R_LIB "Build shared library for R package" OFF)
set(GPU_COMPUTE_VER 35;50;52;60;61 CACHE STRING
"Space separated list of compute versions to be built against")
set(GPU_COMPUTE_VER "" CACHE STRING
"Space separated list of compute versions to be built against, e.g. '35 61'")
# Deprecation warning
if(PLUGIN_UPDATER_GPU)
@ -122,16 +122,13 @@ if(USE_CUDA)
add_definitions(-DXGBOOST_USE_NCCL)
endif()
if((CUDA_VERSION_MAJOR EQUAL 9) OR (CUDA_VERSION_MAJOR GREATER 9))
message("CUDA 9.0 detected, adding Volta compute capability (7.0).")
set(GPU_COMPUTE_VER "${GPU_COMPUTE_VER};70")
endif()
set(GENCODE_FLAGS "")
format_gencode_flags("${GPU_COMPUTE_VER}" GENCODE_FLAGS)
message("cuda architecture flags: ${GENCODE_FLAGS}")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;--expt-relaxed-constexpr;${GENCODE_FLAGS};-lineinfo;")
if(NOT MSVC)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC; -std=c++11")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC; -Xcompiler -Werror; -std=c++11")
endif()
if(USE_NCCL)

View File

@ -54,10 +54,25 @@ function(set_default_configuration_release)
endif()
endfunction(set_default_configuration_release)
# Generate nvcc compiler flags given a list of architectures
# Also generates PTX for the most recent architecture for forwards compatibility
function(format_gencode_flags flags out)
# Set up architecture flags
if(NOT flags)
if((CUDA_VERSION_MAJOR EQUAL 9) OR (CUDA_VERSION_MAJOR GREATER 9))
set(flags "35;50;52;60;61;70")
else()
set(flags "35;50;52;60;61")
endif()
endif()
# Generate SASS
foreach(ver ${flags})
set(${out} "${${out}}-gencode arch=compute_${ver},code=sm_${ver};")
endforeach()
# Generate PTX for last architecture
list(GET flags -1 ver)
set(${out} "${${out}}-gencode arch=compute_${ver},code=compute_${ver};")
set(${out} "${${out}}" PARENT_SCOPE)
endfunction(format_gencode_flags flags)

View File

@ -25,31 +25,6 @@ void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
}
}
void SpeedTest() {
int num_rows = 1000000;
int max_row_size = 100;
dh::CubMemory temp_memory;
thrust::host_vector<int> h_row_ptr;
thrust::host_vector<xgboost::bst_uint> h_rows;
CreateTestData(num_rows, max_row_size, &h_row_ptr, &h_rows);
thrust::device_vector<int> row_ptr = h_row_ptr;
thrust::device_vector<int> output_row(h_rows.size());
auto d_output_row = output_row.data();
xgboost::common::Timer t;
dh::TransformLbs(
0, &temp_memory, h_rows.size(), dh::Raw(row_ptr), row_ptr.size() - 1,
false,
[=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });
dh::safe_cuda(cudaDeviceSynchronize());
double time = t.ElapsedSeconds();
const int mb_size = 1048576;
size_t size = (sizeof(int) * h_rows.size()) / mb_size;
printf("size: %llumb, time: %fs, bandwidth: %fmb/s\n", size, time,
size / time);
}
void TestLbs() {
srand(17);
dh::CubMemory temp_memory;