Run training with empty DMatrix. (#4990)

This makes GPU Hist robust in a distributed environment, where some workers might not be associated with any data in either training or evaluation.

* Disable the rabit mock test for now; see #5012.
* Disable the dask-cudf test at prediction for now; see #5003.
* Launch the dask job on all workers, even those that might not have any data.
* Check for 0 rows in elementwise evaluation metrics.  Using AUC and AUC-PR still throws an error; see #4663 for a robust fix.
* Add tests for edge cases.
* Add a `LaunchKernel` wrapper that handles a zero-sized grid.
* Move some parts of the allreducer into a .cu file.
* Don't validate feature names when the booster is empty.
* Sync the number of columns in DMatrix, as num_feature is required to be the same across all workers in data-split mode.
* Filtering in the dask interface now by default syncs every booster that is not empty, instead of using rank 0.
* Fix Jenkins' GPU tests.
* Install dask-cuda from source in Jenkins' tests, so all tests actually run.
* Restore the GPU Hist tree-synchronization test.
* Check the UUID of running devices.  The check is only performed on CUDA versions >= 10.x, as 9.x lacks the UUID field.
* Fix CMake policy and project variables: use xgboost_SOURCE_DIR uniformly and add a policy for CMake >= 3.13.
* Fix copying data to CPU.
* Fix a race condition in the CPU predictor.
* Fix duplicated DMatrix construction.
* Don't download an extra NCCL in the CI script.
parent 807a244517
commit 7663de956c
@@ -1,9 +1,13 @@
 cmake_minimum_required(VERSION 3.3)
 project(xgboost LANGUAGES CXX C VERSION 1.0.0)
 include(cmake/Utils.cmake)
-list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules")
+list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules")
 cmake_policy(SET CMP0022 NEW)
+
+if ((${CMAKE_VERSION} VERSION_GREATER 3.13) OR (${CMAKE_VERSION} VERSION_EQUAL 3.13))
+  cmake_policy(SET CMP0077 NEW)
+endif ((${CMAKE_VERSION} VERSION_GREATER 3.13) OR (${CMAKE_VERSION} VERSION_EQUAL 3.13))
+
 message(STATUS "CMake version ${CMAKE_VERSION}")
 if (MSVC)
   cmake_minimum_required(VERSION 3.11)
@@ -84,7 +88,7 @@ endif (USE_CUDA)

 # dmlc-core
 msvc_use_static_runtime()
-add_subdirectory(${PROJECT_SOURCE_DIR}/dmlc-core)
+add_subdirectory(${xgboost_SOURCE_DIR}/dmlc-core)
 set_target_properties(dmlc PROPERTIES
   CXX_STANDARD 11
   CXX_STANDARD_REQUIRED ON
@@ -105,7 +109,7 @@ endif(RABIT_MOCK)

 # Exports some R specific definitions and objects
 if (R_LIB)
-  add_subdirectory(${PROJECT_SOURCE_DIR}/R-package)
+  add_subdirectory(${xgboost_SOURCE_DIR}/R-package)
 endif (R_LIB)

 # core xgboost
@@ -123,22 +127,23 @@ target_link_libraries(xgboost PRIVATE ${LINKED_LIBRARIES_PRIVATE})

 # This creates its own shared library `xgboost4j'.
 if (JVM_BINDINGS)
-  add_subdirectory(${PROJECT_SOURCE_DIR}/jvm-packages)
+  add_subdirectory(${xgboost_SOURCE_DIR}/jvm-packages)
 endif (JVM_BINDINGS)
 #-- End shared library

 #-- CLI for xgboost
-add_executable(runxgboost ${PROJECT_SOURCE_DIR}/src/cli_main.cc ${XGBOOST_OBJ_SOURCES})
+add_executable(runxgboost ${xgboost_SOURCE_DIR}/src/cli_main.cc ${XGBOOST_OBJ_SOURCES})
 # For cli_main.cc only
 if (USE_OPENMP)
   find_package(OpenMP REQUIRED)
   target_compile_options(runxgboost PRIVATE ${OpenMP_CXX_FLAGS})
 endif (USE_OPENMP)

 target_include_directories(runxgboost
   PRIVATE
-  ${PROJECT_SOURCE_DIR}/include
-  ${PROJECT_SOURCE_DIR}/dmlc-core/include
-  ${PROJECT_SOURCE_DIR}/rabit/include)
+  ${xgboost_SOURCE_DIR}/include
+  ${xgboost_SOURCE_DIR}/dmlc-core/include
+  ${xgboost_SOURCE_DIR}/rabit/include)
 target_link_libraries(runxgboost PRIVATE ${LINKED_LIBRARIES_PRIVATE})
 set_target_properties(
   runxgboost PROPERTIES
@@ -147,8 +152,8 @@ set_target_properties(
   CXX_STANDARD_REQUIRED ON)
 #-- End CLI for xgboost

-set_output_directory(runxgboost ${PROJECT_SOURCE_DIR})
-set_output_directory(xgboost ${PROJECT_SOURCE_DIR}/lib)
+set_output_directory(runxgboost ${xgboost_SOURCE_DIR})
+set_output_directory(xgboost ${xgboost_SOURCE_DIR}/lib)
 # Ensure these two targets do not build simultaneously, as they produce outputs with conflicting names
 add_dependencies(xgboost runxgboost)

@@ -205,21 +210,21 @@ install(
 if (GOOGLE_TEST)
   enable_testing()
   # Unittests.
-  add_subdirectory(${PROJECT_SOURCE_DIR}/tests/cpp)
+  add_subdirectory(${xgboost_SOURCE_DIR}/tests/cpp)
   add_test(
     NAME TestXGBoostLib
     COMMAND testxgboost
-    WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+    WORKING_DIRECTORY ${xgboost_BINARY_DIR})

   # CLI tests
   configure_file(
-    ${PROJECT_SOURCE_DIR}/tests/cli/machine.conf.in
-    ${PROJECT_BINARY_DIR}/tests/cli/machine.conf
+    ${xgboost_SOURCE_DIR}/tests/cli/machine.conf.in
+    ${xgboost_BINARY_DIR}/tests/cli/machine.conf
     @ONLY)
   add_test(
     NAME TestXGBoostCLI
-    COMMAND runxgboost ${PROJECT_BINARY_DIR}/tests/cli/machine.conf
-    WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+    COMMAND runxgboost ${xgboost_BINARY_DIR}/tests/cli/machine.conf
+    WORKING_DIRECTORY ${xgboost_BINARY_DIR})
   set_tests_properties(TestXGBoostCLI
     PROPERTIES
     PASS_REGULAR_EXPRESSION ".*test-rmse:0.087.*")
Jenkinsfile (vendored, 1 change)

@@ -83,7 +83,6 @@ pipeline {
         'test-python-gpu-cuda10.0': { TestPythonGPU(cuda_version: '10.0') },
         'test-python-gpu-cuda10.1': { TestPythonGPU(cuda_version: '10.1') },
         'test-python-mgpu-cuda10.1': { TestPythonGPU(cuda_version: '10.1', multi_gpu: true) },
-        'test-cpp-rabit': {TestCppRabit()},
         'test-cpp-gpu': { TestCppGPU(cuda_version: '10.1') },
         'test-cpp-mgpu': { TestCppGPU(cuda_version: '10.1', multi_gpu: true) },
         'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '2.4.3') },
@@ -6,7 +6,7 @@ function (run_doxygen)
   endif (NOT DOXYGEN_DOT_FOUND)

   configure_file(
-    ${PROJECT_SOURCE_DIR}/doc/Doxyfile.in
+    ${xgboost_SOURCE_DIR}/doc/Doxyfile.in
     ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile @ONLY)
   add_custom_target( doc_doxygen ALL
     COMMAND ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile
@@ -111,7 +111,7 @@ DESTINATION \"${build_dir}/bak\")")

 install(CODE "file(REMOVE_RECURSE \"${build_dir}/R-package\")")
 install(
-  DIRECTORY "${PROJECT_SOURCE_DIR}/R-package"
+  DIRECTORY "${xgboost_SOURCE_DIR}/R-package"
   DESTINATION "${build_dir}"
   REGEX "src/*" EXCLUDE
   REGEX "R-package/configure" EXCLUDE
@@ -5,6 +5,5 @@ function (write_version)
     ${xgboost_SOURCE_DIR}/include/xgboost/version_config.h @ONLY)
   configure_file(
     ${xgboost_SOURCE_DIR}/cmake/Python_version.in
-    ${xgboost_SOURCE_DIR}/python-package/xgboost/VERSION
-  )
+    ${xgboost_SOURCE_DIR}/python-package/xgboost/VERSION)
 endfunction (write_version)
cmake/modules/FindNVML.cmake (new file, 23 lines)

@@ -0,0 +1,23 @@
+if (NVML_LIBRARY)
+  unset(NVML_LIBRARY CACHE)
+endif(NVML_LIBRARY)
+
+set(NVML_LIB_NAME nvml)
+
+find_path(NVML_INCLUDE_DIR
+  NAMES nvml.h
+  PATHS ${CUDA_HOME}/include ${CUDA_INCLUDE} /usr/local/cuda/include)
+
+find_library(NVML_LIBRARY
+  NAMES nvidia-ml)
+
+message(STATUS "Using nvml library: ${NVML_LIBRARY}")
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NVML DEFAULT_MSG
+  NVML_INCLUDE_DIR NVML_LIBRARY)
+
+mark_as_advanced(
+  NVML_INCLUDE_DIR
+  NVML_LIBRARY
+)
@@ -513,7 +513,7 @@ class DMatrix(object):
            try:
                csr = scipy.sparse.csr_matrix(data)
                self._init_from_csr(csr)
-            except:
+            except Exception:
                raise TypeError('can not initialize DMatrix from'
                                ' {}'.format(type(data).__name__))

@@ -577,9 +577,9 @@ class DMatrix(object):
        if len(mat.shape) != 2:
            raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
                             mat.shape)
-        # flatten the array by rows and ensure it is float32.
-        # we try to avoid data copies if possible (reshape returns a view when possible
-        # and we explicitly tell np.array to try and avoid copying)
+        # flatten the array by rows and ensure it is float32.  we try to avoid
+        # data copies if possible (reshape returns a view when possible and we
+        # explicitly tell np.array to try and avoid copying)
        data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
        handle = ctypes.c_void_p()
        missing = missing if missing is not None else np.nan
@@ -1391,8 +1391,9 @@ class Booster(object):
            value of the prediction. Note the last row and column correspond to the bias term.

        validate_features : bool
-            When this is True, validate that the Booster's and data's feature_names are identical.
-            Otherwise, it is assumed that the feature_names are the same.
+            When this is True, validate that the Booster's and data's
+            feature_names are identical.  Otherwise, it is assumed that the
+            feature_names are the same.

        Returns
        -------
@@ -1811,8 +1812,8 @@ class Booster(object):
            msg = 'feature_names mismatch: {0} {1}'

            if dat_missing:
-                msg += ('\nexpected ' + ', '.join(str(s) for s in dat_missing) +
-                        ' in input data')
+                msg += ('\nexpected ' + ', '.join(
+                    str(s) for s in dat_missing) + ' in input data')

            if my_missing:
                msg += ('\ntraining data did not have the following fields: ' +
@@ -1821,7 +1822,8 @@ class Booster(object):
            raise ValueError(msg.format(self.feature_names,
                                        data.feature_names))

-    def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True):
+    def get_split_value_histogram(self, feature, fmap='', bins=None,
+                                  as_pandas=True):
        """Get split value histogram of a feature

        Parameters
@@ -55,10 +55,14 @@ def _start_tracker(host, n_workers):
    return env


-def _assert_dask_installed():
+def _assert_dask_support():
    if not DASK_INSTALLED:
        raise ImportError(
            'Dask needs to be installed in order to use this module')
+    if platform.system() == 'Windows':
+        msg = 'Windows is not officially supported for dask/xgboost,'
+        msg += ' contributions are welcomed.'
+        logging.warning(msg)


 class RabitContext:
@@ -96,6 +100,11 @@ def _xgb_get_client(client):
    return ret


+def _get_client_workers(client):
+    workers = client.scheduler_info()['workers']
+    return workers
+
+
 class DaskDMatrix:
    # pylint: disable=missing-docstring, too-many-instance-attributes
    '''DMatrix holding on references to Dask DataFrame or Dask Array.
@@ -132,7 +141,7 @@ class DaskDMatrix:
                 weight=None,
                 feature_names=None,
                 feature_types=None):
-        _assert_dask_installed()
+        _assert_dask_support()

        self._feature_names = feature_names
        self._feature_types = feature_types
@@ -263,6 +272,17 @@ class DaskDMatrix:
        A DMatrix object.

        '''
+        if worker.address not in set(self.worker_map.keys()):
+            msg = 'worker {address} has an empty DMatrix.  ' \
+                'All workers associated with this DMatrix: {workers}'.format(
+                    address=worker.address,
+                    workers=set(self.worker_map.keys()))
+            logging.warning(msg)
+            d = DMatrix(numpy.empty((0, 0)),
+                        feature_names=self._feature_names,
+                        feature_types=self._feature_types)
+            return d
+
        data, labels, weights = self.get_worker_parts(worker)

        data = concat(data)
@@ -275,7 +295,6 @@ class DaskDMatrix:
            weights = concat(weights)
        else:
            weights = None
-
        dmatrix = DMatrix(data,
                          labels,
                          weight=weights,
@@ -342,35 +361,33 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
            'eval': {'logloss': ['0.480385', '0.357756']}}}

    '''
-    _assert_dask_installed()
-    if platform.system() == 'Windows':
-        msg = 'Windows is not officially supported for dask/xgboost,'
-        msg += ' contribution are welcomed.'
-        logging.warning(msg)
-
+    _assert_dask_support()
    if 'evals_result' in kwargs.keys():
        raise ValueError(
            'evals_result is not supported in dask interface.',
            'The evaluation history is returned as result of training.')

    client = _xgb_get_client(client)
+    workers = list(_get_client_workers(client).keys())

-    worker_map = dtrain.worker_map
-    rabit_args = _get_rabit_args(worker_map, client)
+    rabit_args = _get_rabit_args(workers, client)

-    def dispatched_train(worker_id):
-        '''Perform training on worker.'''
-        logging.info('Training on %d', worker_id)
+    def dispatched_train(worker_addr):
+        '''Perform training on a single worker.'''
+        logging.info('Training on %s', str(worker_addr))
        worker = distributed_get_worker()
-        local_dtrain = dtrain.get_worker_data(worker)
-
-        local_evals = []
-        if evals:
-            for mat, name in evals:
-                local_mat = mat.get_worker_data(worker)
-                local_evals.append((local_mat, name))

        with RabitContext(rabit_args):
+            local_dtrain = dtrain.get_worker_data(worker)
+
+            local_evals = []
+            if evals:
+                for mat, name in evals:
+                    if mat is dtrain:
+                        local_evals.append((local_dtrain, name))
+                        continue
+                    local_mat = mat.get_worker_data(worker)
+                    local_evals.append((local_mat, name))
+
            local_history = {}
            local_param = params.copy()  # just to be consistent
            bst = worker_train(params=local_param,
@@ -380,14 +397,14 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
                               evals=local_evals,
                               **kwargs)
            ret = {'booster': bst, 'history': local_history}
-            if rabit.get_rank() != 0:
+            if local_dtrain.num_row() == 0:
                ret = None
            return ret

    futures = client.map(dispatched_train,
-                         range(len(worker_map)),
+                         workers,
                         pure=False,
-                         workers=list(worker_map.keys()))
+                         workers=workers)
    results = client.gather(futures)
    return list(filter(lambda ret: ret is not None, results))[0]
@@ -414,7 +431,7 @@ def predict(client, model, data, *args):
    prediction: dask.array.Array

    '''
-    _assert_dask_installed()
+    _assert_dask_support()
    if isinstance(model, Booster):
        booster = model
    elif isinstance(model, dict):
@@ -437,7 +454,8 @@ def predict(client, model, data, *args):
        local_x = data.get_worker_data(worker)

        with RabitContext(rabit_args):
-            local_predictions = booster.predict(data=local_x, *args)
+            local_predictions = booster.predict(
+                data=local_x, validate_features=local_x.num_row() != 0, *args)
        return local_predictions

    futures = client.map(dispatched_predict,
@@ -563,7 +581,7 @@ class DaskXGBRegressor(DaskScikitLearnBase):
            sample_weights=None,
            eval_set=None,
            sample_weight_eval_set=None):
-        _assert_dask_installed()
+        _assert_dask_support()
        dtrain = DaskDMatrix(client=self.client,
                             data=X, label=y, weight=sample_weights)
        params = self.get_xgb_params()
@@ -579,7 +597,7 @@ class DaskXGBRegressor(DaskScikitLearnBase):
        return self

    def predict(self, data):  # pylint: disable=arguments-differ
-        _assert_dask_installed()
+        _assert_dask_support()
        test_dmatrix = DaskDMatrix(client=self.client, data=data)
        pred_probs = predict(client=self.client,
                             model=self.get_booster(), data=test_dmatrix)
@@ -599,7 +617,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
            sample_weights=None,
            eval_set=None,
            sample_weight_eval_set=None):
-        _assert_dask_installed()
+        _assert_dask_support()
        dtrain = DaskDMatrix(client=self.client,
                             data=X, label=y, weight=sample_weights)
        params = self.get_xgb_params()
@@ -626,7 +644,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
        return self

    def predict(self, data):  # pylint: disable=arguments-differ
-        _assert_dask_installed()
+        _assert_dask_support()
        test_dmatrix = DaskDMatrix(client=self.client, data=data)
        pred_probs = predict(client=self.client,
                             model=self.get_booster(), data=test_dmatrix)
@@ -332,7 +332,7 @@ class RabitTracker(object):
        self.thread.start()

    def join(self):
-        while self.thread.isAlive():
+        while self.thread.is_alive():
            self.thread.join(100)

    def alive(self):
@@ -1,5 +1,5 @@
 file(GLOB_RECURSE CPU_SOURCES *.cc *.h)
-list(REMOVE_ITEM CPU_SOURCES ${PROJECT_SOURCE_DIR}/src/cli_main.cc)
+list(REMOVE_ITEM CPU_SOURCES ${xgboost_SOURCE_DIR}/src/cli_main.cc)

 #-- Object library
 # Object library is necessary for jvm-package, which creates its own shared
@@ -9,7 +9,7 @@ if (USE_CUDA)
   add_library(objxgboost OBJECT ${CPU_SOURCES} ${CUDA_SOURCES} ${PLUGINS_SOURCES})
   target_compile_definitions(objxgboost
     PRIVATE -DXGBOOST_USE_CUDA=1)
-  target_include_directories(objxgboost PRIVATE ${PROJECT_SOURCE_DIR}/cub/)
+  target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/cub/)
   target_compile_options(objxgboost PRIVATE
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
@@ -43,9 +43,9 @@ endif (USE_CUDA)

 target_include_directories(objxgboost
   PRIVATE
-  ${PROJECT_SOURCE_DIR}/include
-  ${PROJECT_SOURCE_DIR}/dmlc-core/include
-  ${PROJECT_SOURCE_DIR}/rabit/include)
+  ${xgboost_SOURCE_DIR}/include
+  ${xgboost_SOURCE_DIR}/dmlc-core/include
+  ${xgboost_SOURCE_DIR}/rabit/include)
 target_compile_options(objxgboost
   PRIVATE
   $<$<AND:$<CXX_COMPILER_ID:MSVC>,$<COMPILE_LANGUAGE:CXX>>:/MP>
src/common/device_helpers.cu (new file, 91 lines)

@@ -0,0 +1,91 @@
+/*!
+ * Copyright 2017-2019 XGBoost contributors
+ *
+ * \brief Utilities for CUDA.
+ */
+#ifdef XGBOOST_USE_NCCL
+#include <nccl.h>
+#endif  // #ifdef XGBOOST_USE_NCCL
+#include <sstream>
+
+#include "device_helpers.cuh"
+
+namespace dh {
+
+#if __CUDACC_VER_MAJOR__ > 9
+constexpr std::size_t kUuidLength =
+    sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
+
+void GetCudaUUID(int world_size, int rank, int device_ord,
+                 xgboost::common::Span<uint64_t, kUuidLength> uuid) {
+  cudaDeviceProp prob;
+  safe_cuda(cudaGetDeviceProperties(&prob, device_ord));
+  std::memcpy(uuid.data(), static_cast<void*>(&(prob.uuid)), sizeof(prob.uuid));
+}
+
+std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
+  std::stringstream ss;
+  for (auto v : uuid) {
+    ss << std::hex << v;
+  }
+  return ss.str();
+}
+#endif  // __CUDACC_VER_MAJOR__ > 9
+
+void AllReducer::Init(int _device_ordinal) {
+#ifdef XGBOOST_USE_NCCL
+  LOG(DEBUG) << "Running nccl init on: " << __CUDACC_VER_MAJOR__ << "." << __CUDACC_VER_MINOR__;
+
+  device_ordinal = _device_ordinal;
+  int32_t const rank = rabit::GetRank();
+
+#if __CUDACC_VER_MAJOR__ > 9
+  int32_t const world = rabit::GetWorldSize();
+
+  std::vector<uint64_t> uuids(world * kUuidLength, 0);
+  auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
+  auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
+  GetCudaUUID(world, rank, device_ordinal, s_this_uuid);
+
+  // No allgather yet.
+  rabit::Allreduce<rabit::op::Sum, uint64_t>(uuids.data(), uuids.size());
+
+  std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
+  size_t j = 0;
+  for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
+    converted[j] =
+        xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
+    j++;
+  }
+
+  auto iter = std::unique(converted.begin(), converted.end());
+  auto n_uniques = std::distance(converted.begin(), iter);
+  CHECK_EQ(n_uniques, world)
+      << "Multiple processes within communication group running on same CUDA "
+      << "device is not supported";
+#endif  // __CUDACC_VER_MAJOR__ > 9
+
+  id = GetUniqueId();
+  dh::safe_cuda(cudaSetDevice(device_ordinal));
+  dh::safe_nccl(ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rank));
+  safe_cuda(cudaStreamCreate(&stream));
+  initialised_ = true;
+#endif  // XGBOOST_USE_NCCL
+}
+
+AllReducer::~AllReducer() {
+#ifdef XGBOOST_USE_NCCL
+  if (initialised_) {
+    dh::safe_cuda(cudaStreamDestroy(stream));
+    ncclCommDestroy(comm);
+  }
+  if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
+    LOG(CONSOLE) << "======== NCCL Statistics========";
+    LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
+    LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576;
+  }
+#endif  // XGBOOST_USE_NCCL
+}
+
+}  // namespace dh
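The PrintUUID helper above simply streams each 64-bit half of the 128-bit device UUID as hex. A minimal standalone sketch of the same formatting, with made-up UUID values (note that std::hex drops leading zeros, so this is a debugging aid rather than a canonical UUID rendering):

#include <cstdint>
#include <iostream>
#include <sstream>
#include <vector>

int main() {
  // Two made-up 64-bit halves standing in for a 128-bit cudaDeviceProp::uuid.
  std::vector<uint64_t> uuid{0x0123456789abcdefULL, 0xfedcba9876543210ULL};
  std::stringstream ss;
  for (auto v : uuid) {
    ss << std::hex << v;  // same formatting PrintUUID uses; leading zero is dropped
  }
  std::cout << ss.str() << "\n";  // 123456789abcdeffedcba9876543210
  return 0;
}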
@@ -7,24 +7,25 @@
 #include <thrust/device_malloc_allocator.h>
 #include <thrust/system/cuda/error.h>
 #include <thrust/system_error.h>
-#include <xgboost/logging.h>
+
+#include <omp.h>
 #include <rabit/rabit.h>
+#include <cub/cub.cuh>
 #include <cub/util_allocator.cuh>

-#include "xgboost/host_device_vector.h"
-#include "xgboost/span.h"
-
-#include "common.h"
-
 #include <algorithm>
-#include <omp.h>
 #include <chrono>
 #include <ctime>
-#include <cub/cub.cuh>
 #include <numeric>
 #include <sstream>
 #include <string>
 #include <vector>

+#include "xgboost/logging.h"
+#include "xgboost/host_device_vector.h"
+#include "xgboost/span.h"
+
+#include "common.h"
 #include "timer.h"

 #ifdef XGBOOST_USE_NCCL
@@ -205,24 +206,53 @@ __global__ void LaunchNKernel(size_t begin, size_t end, L lambda) {
 }
 template <typename L>
 __global__ void LaunchNKernel(int device_idx, size_t begin, size_t end,
                               L lambda) {
   for (auto i : GridStrideRange(begin, end)) {
     lambda(i, device_idx);
   }
 }

+/* \brief A wrapper around kernel launching syntax, used to guard against empty input.
+ *
+ * - nvcc fails to deduce the template argument when the kernel is a template accepting a
+ *   __device__ function as argument.  Hence functions like `LaunchN` cannot use this wrapper.
+ *
+ * - With the C++ initialization list `{}` syntax, you are forced to comply with the CUDA type
+ *   specification.
+ */
+class LaunchKernel {
+  size_t shmem_size_;
+  cudaStream_t stream_;
+
+  dim3 grids_;
+  dim3 blocks_;
+
+ public:
+  LaunchKernel(uint32_t _grids, uint32_t _blk, size_t _shmem=0, cudaStream_t _s=0) :
+      grids_{_grids, 1, 1}, blocks_{_blk, 1, 1}, shmem_size_{_shmem}, stream_{_s} {}
+  LaunchKernel(dim3 _grids, dim3 _blk, size_t _shmem=0, cudaStream_t _s=0) :
+      grids_{_grids}, blocks_{_blk}, shmem_size_{_shmem}, stream_{_s} {}
+
+  template <typename K, typename... Args>
+  void operator()(K kernel, Args... args) {
+    if (XGBOOST_EXPECT(grids_.x * grids_.y * grids_.z == 0, false)) {
+      LOG(DEBUG) << "Skipping empty CUDA kernel.";
+      return;
+    }
+    kernel<<<grids_, blocks_, shmem_size_, stream_>>>(args...);  // NOLINT
+  }
+};
+
 template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
 inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {
   if (n == 0) {
     return;
   }

   safe_cuda(cudaSetDevice(device_idx));

   const int GRID_SIZE =
       static_cast<int>(xgboost::common::DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
-  LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(static_cast<size_t>(0),
-                                                         n, lambda);
+  LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(  // NOLINT
+      static_cast<size_t>(0), n, lambda);
 }

 // Default stream version
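To see why the guard matters: launching a CUDA kernel with a zero-sized grid fails with cudaErrorInvalidConfiguration, which is exactly what a worker holding an empty DMatrix would otherwise trigger. Below is a self-contained sketch of the same pattern; the Launcher class and WriteOne kernel are illustrative stand-ins, not part of XGBoost:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void WriteOne(int* out, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] = 1; }
}

// Stand-in for dh::LaunchKernel, reduced to the zero-grid guard.
class Launcher {
  dim3 grids_, blocks_;
 public:
  Launcher(uint32_t g, uint32_t b) : grids_{g, 1, 1}, blocks_{b, 1, 1} {}
  template <typename K, typename... Args>
  void operator()(K kernel, Args... args) {
    if (grids_.x * grids_.y * grids_.z == 0) { return; }  // empty worker: skip launch
    kernel<<<grids_, blocks_>>>(args...);
  }
};

int main() {
  size_t n = 0;  // an empty DMatrix yields zero rows
  uint32_t const kThreads = 256;
  uint32_t const kGrids = static_cast<uint32_t>((n + kThreads - 1) / kThreads);  // 0 when n == 0
  int* out = nullptr;  // nothing to allocate for zero rows
  Launcher{kGrids, kThreads}(WriteOne, out, n);  // no-op instead of an invalid launch
  cudaDeviceSynchronize();
  std::printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));  // no error
  return 0;
}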
@@ -301,6 +331,16 @@ inline detail::MemoryLogger &GlobalMemoryLogger() {
   return memory_logger;
 }

+// Usage: dh::DebugSyncDevice(__FILE__, __LINE__);
+inline void DebugSyncDevice(std::string file="", int32_t line = -1) {
+  if (file != "" && line != -1) {
+    auto rank = rabit::GetRank();
+    LOG(DEBUG) << "R:" << rank << ": " << file << ":" << line;
+  }
+  safe_cuda(cudaDeviceSynchronize());
+  safe_cuda(cudaGetLastError());
+}
+
 namespace detail{
 /**
  * \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
@@ -763,7 +803,7 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
       BLOCK_THREADS, segments, num_segments, count);

   LbsKernel<TILE_SIZE, ITEMS_PER_THREAD, BLOCK_THREADS, OffsetT>
-      <<<uint32_t(num_tiles), BLOCK_THREADS>>>(tmp_tile_coordinates,
+      <<<uint32_t(num_tiles), BLOCK_THREADS>>>(tmp_tile_coordinates,  // NOLINT
                                                segments + 1, f, num_segments);
 }

@@ -963,7 +1003,6 @@ class SaveCudaContext {
  * streams. Must be initialised before use. If XGBoost is compiled without NCCL
  * this is a dummy class that will error if used with more than one GPU.
  */
-
 class AllReducer {
   bool initialised_;
   size_t allreduce_bytes_;  // Keep statistics of the number of bytes communicated
@@ -986,31 +1025,9 @@ class AllReducer {
   *
   * \param device_ordinal The device ordinal.
   */
-  void Init(int _device_ordinal) {
-#ifdef XGBOOST_USE_NCCL
-    /** \brief this >monitor . init. */
-    device_ordinal = _device_ordinal;
-    id = GetUniqueId();
-    dh::safe_cuda(cudaSetDevice(device_ordinal));
-    dh::safe_nccl(ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rabit::GetRank()));
-    safe_cuda(cudaStreamCreate(&stream));
-    initialised_ = true;
-#endif
-  }
-  ~AllReducer() {
-#ifdef XGBOOST_USE_NCCL
-    if (initialised_) {
-      dh::safe_cuda(cudaStreamDestroy(stream));
-      ncclCommDestroy(comm);
-    }
-    if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
-      LOG(CONSOLE) << "======== NCCL Statistics========";
-      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
-      LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576;
-    }
-#endif
-  }
+  void Init(int _device_ordinal);
+
+  ~AllReducer();

  /**
   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
@@ -293,6 +293,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {

 void DenseCuts::Init
 (std::vector<WXQSketch>* in_sketchs, uint32_t max_num_bins) {
+  monitor_.Start(__func__);
   std::vector<WXQSketch>& sketchs = *in_sketchs;
   constexpr int kFactor = 8;
   // gather the histogram data
@@ -332,6 +333,7 @@ void DenseCuts::Init
     CHECK_GT(cut_size, p_cuts_->cut_ptrs_.back());
     p_cuts_->cut_ptrs_.push_back(cut_size);
   }
+  monitor_.Stop(__func__);
 }

 void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
@@ -252,8 +252,10 @@ class GPUSketcher {
         });
     } else if (n_cuts_cur_[icol] > 0) {
       // if more elements than cuts: use binary search on cumulative weights
-      int block = 256;
-      FindCutsK<<<common::DivRoundUp(n_cuts_cur_[icol], block), block>>>(
+      uint32_t constexpr kBlockThreads = 256;
+      uint32_t const kGrids = common::DivRoundUp(n_cuts_cur_[icol], kBlockThreads);
+      dh::LaunchKernel {kGrids, kBlockThreads} (
+          FindCutsK,
           cuts_d_.data().get() + icol * n_cuts_,
           fvalues_cur_.data().get(),
           weights2_.data().get(),
@@ -403,7 +405,8 @@ class GPUSketcher {
       // NOTE: This will typically support ~ 4M features - 64K*64
       dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
                  common::DivRoundUp(num_cols_, block3.y), 1);
-      UnpackFeaturesK<<<grid3, block3>>>(
+      dh::LaunchKernel {grid3, block3} (
+          UnpackFeaturesK,
           fvalues_.data().get(),
           has_weights_ ? feature_weights_.data().get() : nullptr,
           row_ptrs_.data().get() + batch_row_begin,
@@ -13,6 +13,20 @@
 namespace xgboost {
 namespace common {

+void Monitor::Start(std::string const &name) {
+  if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+    statistics_map[name].timer.Start();
+  }
+}
+
+void Monitor::Stop(const std::string &name) {
+  if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+    auto &stats = statistics_map[name];
+    stats.timer.Stop();
+    stats.count++;
+  }
+}
+
 std::vector<Monitor::StatMap> Monitor::CollectFromOtherRanks() const {
   // Since other nodes might have started timers that this one hasn't,
   // we can't simply call allreduce.
src/common/timer.cu (new file, 38 lines)

@@ -0,0 +1,38 @@
+/*!
+ * Copyright by Contributors 2019
+ */
+#if defined(XGBOOST_USE_NVTX)
+#include <nvToolsExt.h>
+#endif  // defined(XGBOOST_USE_NVTX)
+
+#include <string>
+
+#include "xgboost/logging.h"
+#include "device_helpers.cuh"
+#include "timer.h"
+
+namespace xgboost {
+namespace common {
+
+void Monitor::StartCuda(const std::string& name) {
+  if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+    auto &stats = statistics_map[name];
+    stats.timer.Start();
+#if defined(XGBOOST_USE_NVTX)
+    stats.nvtx_id = nvtxRangeStartA(name.c_str());
+#endif  // defined(XGBOOST_USE_NVTX)
+  }
+}
+
+void Monitor::StopCuda(const std::string& name) {
+  if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+    auto &stats = statistics_map[name];
+    stats.timer.Stop();
+    stats.count++;
+#if defined(XGBOOST_USE_NVTX)
+    nvtxRangeEnd(stats.nvtx_id);
+#endif  // defined(XGBOOST_USE_NVTX)
+  }
+}
+}  // namespace common
+}  // namespace xgboost
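For readers unfamiliar with NVTX: nvtxRangeStartA/nvtxRangeEnd bracket a named range that shows up in profiler timelines such as Nsight. A minimal standalone sketch of the pattern StartCuda/StopCuda wire into the monitor (the file name and build command are assumptions):

// Build with something like: nvcc demo.cu -lnvToolsExt
#include <nvToolsExt.h>

int main() {
  // Open a named range; it appears in Nsight/nvprof timelines.
  nvtxRangeId_t id = nvtxRangeStartA("BuildHist");
  // ... timed GPU work would go here ...
  nvtxRangeEnd(id);  // close the range
  return 0;
}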
@@ -10,10 +10,6 @@
 #include <utility>
 #include <vector>

-#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-#include <nvToolsExt.h>
-#endif  // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-
 namespace xgboost {
 namespace common {

@@ -84,37 +80,10 @@ struct Monitor {
   void Print() const;

   void Init(std::string label) { this->label = label; }
-  void Start(const std::string &name) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-      statistics_map[name].timer.Start();
-    }
-  }
-  void Stop(const std::string &name) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-      auto &stats = statistics_map[name];
-      stats.timer.Stop();
-      stats.count++;
-    }
-  }
-  void StartCuda(const std::string &name) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-      auto &stats = statistics_map[name];
-      stats.timer.Start();
-#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-      stats.nvtx_id = nvtxRangeStartA(name.c_str());
-#endif  // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-    }
-  }
-  void StopCuda(const std::string &name) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-      auto &stats = statistics_map[name];
-      stats.timer.Stop();
-      stats.count++;
-#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-      nvtxRangeEnd(stats.nvtx_id);
-#endif  // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-    }
-  }
+  void Start(const std::string &name);
+  void Stop(const std::string &name);
+  void StartCuda(const std::string &name);
+  void StopCuda(const std::string &name);
 };
 }  // namespace common
 }  // namespace xgboost
@@ -133,9 +133,12 @@ class Transform {
     size_t shard_size = range_size;
     Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
     dh::safe_cuda(cudaSetDevice(device_));
-    const int GRID_SIZE =
+    const int kGrids =
         static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads));
-    detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
+    if (kGrids == 0) {
+      return;
+    }
+    detail::LaunchCUDAKernel<<<kGrids, kBlockThreads>>>(  // NOLINT
        _func, shard_range, UnpackHDVOnDevice(_vectors)...);
   }
 #else
@@ -320,6 +320,32 @@ void DMatrix::SaveToLocalFile(const std::string& fname) {
 DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
                          const std::string& cache_prefix) {
   if (cache_prefix.length() == 0) {
+    // FIXME(trivialfis): Currently distcol is broken, so here we check the number of rows.
+    // If we bring back column split this check will break.
+    bool is_distributed { rabit::IsDistributed() };
+    if (is_distributed) {
+      auto world_size = rabit::GetWorldSize();
+      auto rank = rabit::GetRank();
+      std::vector<uint64_t> ncols(world_size, 0);
+      ncols[rank] = source->info.num_col_;
+      rabit::Allreduce<rabit::op::Sum>(ncols.data(), ncols.size());
+      auto max_cols = std::max_element(ncols.cbegin(), ncols.cend());
+      auto max_ind = std::distance(ncols.cbegin(), max_cols);
+      // FIXME(trivialfis): This is a hack; we should store a reference to the global shape if possible.
+      if (source->info.num_col_ == 0 && source->info.num_row_ == 0) {
+        LOG(WARNING) << "DMatrix at rank: " << rank << " worker is empty.";
+        source->info.num_col_ = *max_cols;
+      }
+
+      // Validate the number of columns across all workers.
+      for (size_t i = 0; i < ncols.size(); ++i) {
+        auto v = ncols[i];
+        CHECK(v == 0 || v == *max_cols)
+            << "DMatrix at rank: " << i << " worker "
+            << "has different number of columns than rank: " << max_ind << " worker. "
+            << "(" << v << " vs. " << *max_cols << ")";
+      }
+    }
     return new data::SimpleDMatrix(std::move(source));
   } else {
 #if DMLC_ENABLE_STD_THREAD
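The sync above uses a simple trick: each rank writes its own column count into a zero-initialized vector, so a sum-allreduce doubles as an allgather, and an empty rank can then adopt the maximum. A plain C++ simulation of the idea with three hypothetical workers (no rabit involved):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Three hypothetical workers; rank 1 holds no data, ranks 0 and 2 have 42 columns.
  int const kWorldSize = 3;
  std::vector<std::vector<uint64_t>> per_rank(kWorldSize, std::vector<uint64_t>(kWorldSize, 0));
  per_rank[0][0] = 42;  // each rank fills only its own slot: ncols[rank] = num_col_
  per_rank[1][1] = 0;
  per_rank[2][2] = 42;

  // Element-wise sum plays the role of rabit::Allreduce<rabit::op::Sum>.
  std::vector<uint64_t> ncols(kWorldSize, 0);
  for (auto const& contrib : per_rank) {
    for (int i = 0; i < kWorldSize; ++i) { ncols[i] += contrib[i]; }
  }

  auto max_cols = *std::max_element(ncols.cbegin(), ncols.cend());
  uint64_t my_num_col = 0;                      // rank 1 is empty...
  if (my_num_col == 0) { my_num_col = max_cols; }  // ...so it adopts the global width
  std::printf("synced num_col: %llu\n", static_cast<unsigned long long>(my_num_col));  // 42
  return 0;
}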
@@ -99,13 +99,13 @@ EllpackInfo::EllpackInfo(int device,
                          bool is_dense,
                          size_t row_stride,
                          const common::HistogramCuts& hmat,
-                         dh::BulkAllocator& ba)
+                         dh::BulkAllocator* ba)
     : is_dense(is_dense), row_stride(row_stride), n_bins(hmat.Ptrs().back()) {

-  ba.Allocate(device,
+  ba->Allocate(device,
              &feature_segments, hmat.Ptrs().size(),
              &gidx_fvalue_map, hmat.Values().size(),
              &min_fvalue, hmat.MinValues().size());
   dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
   dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
   dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
@@ -116,7 +116,7 @@ void EllpackPageImpl::InitInfo(int device,
                                bool is_dense,
                                size_t row_stride,
                                const common::HistogramCuts& hmat) {
-  matrix.info = EllpackInfo(device, is_dense, row_stride, hmat, ba_);
+  matrix.info = EllpackInfo(device, is_dense, row_stride, hmat, &ba_);
 }

 // Initialize the buffer to stored compressed features.
@@ -189,7 +189,8 @@ void EllpackPageImpl::CreateHistIndices(int device,
   const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
                    common::DivRoundUp(row_stride, block3.y),
                    1);
-  CompressBinEllpackKernel<<<grid3, block3>>>(
+  dh::LaunchKernel {grid3, block3} (
+      CompressBinEllpackKernel,
      common::CompressedBufferWriter(num_symbols),
      gidx_buffer.data(),
      row_ptrs.data().get(),
|
|||||||
bool is_dense,
|
bool is_dense,
|
||||||
size_t row_stride,
|
size_t row_stride,
|
||||||
const common::HistogramCuts& hmat,
|
const common::HistogramCuts& hmat,
|
||||||
dh::BulkAllocator& ba);
|
dh::BulkAllocator* ba);
|
||||||
};
|
};
|
||||||
|
|
||||||
/** \brief Struct for accessing and manipulating an ellpack matrix on the
|
/** \brief Struct for accessing and manipulating an ellpack matrix on the
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
   monitor_.StopCuda("Quantiles");

   monitor_.StartCuda("CreateEllpackInfo");
-  ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, ba_);
+  ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, &ba_);
   monitor_.StopCuda("CreateEllpackInfo");

   monitor_.StartCuda("WriteEllpackPages");
@@ -101,7 +101,7 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
                 HostDeviceVector<size_t>* out_offset,
                 dh::caching_device_vector<int32_t>* out_d_flag,
                 uint32_t* out_n_rows) {
-  int32_t constexpr kThreads = 256;
+  uint32_t constexpr kThreads = 256;
   auto const& j_column = j_columns[column_id];
   auto const& column_obj = get<Object const>(j_column);
   Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj);
@@ -123,8 +123,9 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,

   common::Span<size_t> s_offsets = out_offset->DeviceSpan();

-  int32_t const kBlocks = common::DivRoundUp(n_rows, kThreads);
-  CountValidKernel<T><<<kBlocks, kThreads>>>(
+  uint32_t const kBlocks = common::DivRoundUp(n_rows, kThreads);
+  dh::LaunchKernel {kBlocks, kThreads} (
+      CountValidKernel<T>,
      foreign_column,
      has_missing, missing,
      out_d_flag->data().get(), s_offsets);
@@ -135,13 +136,15 @@ template <typename T>
 void CreateCSR(std::vector<Json> const& j_columns, uint32_t column_id, uint32_t n_rows,
                bool has_missing, float missing,
                dh::device_vector<size_t>* tmp_offset, common::Span<Entry> s_data) {
-  int32_t constexpr kThreads = 256;
+  uint32_t constexpr kThreads = 256;
   auto const& j_column = j_columns[column_id];
   auto const& column_obj = get<Object const>(j_column);
   Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj);
-  int32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
-  CreateCSRKernel<T><<<kBlocks, kThreads>>>(foreign_column, column_id, has_missing, missing,
-                                            dh::ToSpan(*tmp_offset), s_data);
+  uint32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
+  dh::LaunchKernel {kBlocks, kThreads} (
+      CreateCSRKernel<T>,
+      foreign_column, column_id, has_missing, missing,
+      dh::ToSpan(*tmp_offset), s_data);
 }

 void SimpleCSRSource::FromDeviceColumnar(std::vector<Json> const& columns,
@@ -246,6 +246,14 @@ class GBTree : public GradientBooster {
   std::unique_ptr<Predictor> const& GetPredictor(HostDeviceVector<float> const* out_pred = nullptr,
                                                  DMatrix* f_dmat = nullptr) const {
     CHECK(configured_);
+    auto on_device = f_dmat && (*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
+#if defined(XGBOOST_USE_CUDA)
+    // Use GPU Predictor if data is already on device.
+    if (!specified_predictor_ && on_device) {
+      CHECK(gpu_predictor_);
+      return gpu_predictor_;
+    }
+#endif  // defined(XGBOOST_USE_CUDA)
     // GPU_Hist by default has prediction cache calculated from quantile values, so GPU
     // Predictor is not used for training dataset. But when XGBoost performs continued
     // training with an existing model, the prediction cache is not available and number
@@ -256,7 +264,7 @@ class GBTree : public GradientBooster {
         (model_.param.num_trees != 0) &&
         // FIXME(trivialfis): Implement a better method for testing whether data is on
         // device after DMatrix refactoring is done.
-        (f_dmat && !((*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead()))) {
+        !on_device) {
       return cpu_predictor_;
     }
     if (tparam_.predictor == "cpu_predictor") {
@ -630,7 +630,7 @@ class LearnerImpl : public Learner {
     CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
         << "Unfortunately, XGBoost does not support data matrices with "
         << std::numeric_limits<unsigned>::max() << " features or greater";
-    num_feature = std::max(num_feature, static_cast<unsigned>(num_col));
+    num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
   }
   // run allreduce on num_feature to find the maximum value
   rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr, "num_feature");
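
The cast change here is cosmetic, but the Allreduce line is load-bearing: in data-split mode every worker must agree on `num_feature`, and a worker whose shard is empty would otherwise contribute zero columns. A minimal sketch of the synchronisation, assuming a rabit tracker is available; the toy `main` and the local column count are made up:

    // Toy illustration of syncing num_feature across workers with rabit.
    #include <rabit/rabit.h>
    #include <cstdint>
    #include <cstdio>

    int main(int argc, char* argv[]) {
      rabit::Init(argc, argv);
      // An empty shard reports 0 local columns; the max-allreduce makes
      // every worker agree on the global feature count afterwards.
      uint32_t num_feature = 0;  // stand-in for the local column count
      rabit::Allreduce<rabit::op::Max>(&num_feature, 1);
      std::printf("[%d] num_feature = %u\n", rabit::GetRank(), num_feature);
      rabit::Finalize();
      return 0;
    }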
@ -3,6 +3,8 @@
  * \file elementwise_metric.cc
  * \brief evaluation metrics for elementwise binary or regression.
  * \author Kailong Chen, Tianqi Chen
+ *
+ *  Expressions like wsum == 0 ? esum : esum / wsum are used to handle an empty dataset.
  */
 #include <rabit/rabit.h>
 #include <xgboost/metric.h>
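
The reason for the guard, spelled out: with zero rows both the residue sum `esum` and the weight sum `wsum` are 0, so the unguarded `esum / wsum` is 0/0 = NaN, which would then poison the sum-allreduce across workers. A standalone illustration (a toy function, not the actual XGBoost types):

    // Toy version of the guarded GetFinal used by the elementwise metrics.
    #include <cassert>
    #include <cmath>

    float FinalRMSE(float esum, float wsum) {
      return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum);
    }

    int main() {
      assert(FinalRMSE(0.f, 0.f) == 0.f);    // empty worker: well defined
      float z = 0.f;
      assert(std::isnan(std::sqrt(z / z)));  // the unguarded form yields NaN
      return 0;
    }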
@ -142,7 +144,7 @@ struct EvalRowRMSE {
     return diff * diff;
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return std::sqrt(esum / wsum);
+    return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum);
   }
 };

@ -150,12 +152,13 @@ struct EvalRowRMSLE {
   char const* Name() const {
     return "rmsle";
   }

   XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
     bst_float diff = std::log1p(label) - std::log1p(pred);
     return diff * diff;
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return std::sqrt(esum / wsum);
+    return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum);
   }
 };

@ -168,7 +171,7 @@ struct EvalRowMAE {
     return std::abs(label - pred);
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }
 };

@ -190,7 +193,7 @@ struct EvalRowLogLoss {
   }

   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }
 };

@ -225,7 +228,7 @@ struct EvalError {
   }

   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }

  private:
@ -245,7 +248,7 @@ struct EvalPoissonNegLogLik {
   }

   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }
 };

@ -278,7 +281,7 @@ struct EvalGammaNLogLik {
     return -((y * theta - b) / a + c);
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }
 };

@ -304,7 +307,7 @@ struct EvalTweedieNLogLik {
     return -a + b;
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }

  protected:
@ -323,7 +326,9 @@ struct EvalEWiseBase : public Metric {
   bst_float Eval(const HostDeviceVector<bst_float>& preds,
                  const MetaInfo& info,
                  bool distributed) override {
-    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
+    if (info.labels_.Size() == 0) {
+      LOG(WARNING) << "label set is empty";
+    }
     CHECK_EQ(preds.Size(), info.labels_.Size())
         << "label and prediction size not match, "
         << "hint: use merror or mlogloss for multi-class classification";
@ -333,6 +338,7 @@ struct EvalEWiseBase : public Metric {
     reducer_.Reduce(*tparam_, device, info.weights_, info.labels_, preds);

     double dat[2] { result.Residue(), result.Weights() };
+
     if (distributed) {
       rabit::Allreduce<rabit::op::Sum>(dat, 2);
     }
@ -54,7 +54,9 @@ class RegLossObj : public ObjFunction {
                    const MetaInfo &info,
                    int iter,
                    HostDeviceVector<GradientPair>* out_gpair) override {
-    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
+    if (info.labels_.Size() == 0U) {
+      LOG(WARNING) << "Label set is empty.";
+    }
     CHECK_EQ(preds.Size(), info.labels_.Size())
         << "labels are not correctly provided"
         << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();
@ -60,6 +60,9 @@ class CPUPredictor : public Predictor {
     constexpr int kUnroll = 8;
     const auto nsize = static_cast<bst_omp_uint>(batch.Size());
     const bst_omp_uint rest = nsize % kUnroll;
+    // Pull to host before entering omp block, as this is not thread safe.
+    batch.data.HostVector();
+    batch.offset.HostVector();
 #pragma omp parallel for schedule(static)
     for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
       const int tid = omp_get_thread_num();
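
The two `HostVector()` calls fix the CPU predictor race: `HostDeviceVector` copies device data to host lazily on first host access, and that first, mutating access is not thread safe, so it has to happen before the OpenMP region rather than inside it. A self-contained toy reproducing the pattern; `LazyHost` is a stand-in, not the real `HostDeviceVector` API:

    // Toy stand-in for HostDeviceVector's lazy host synchronisation.
    #include <cstdio>
    #include <vector>

    class LazyHost {
     public:
      std::vector<float> const& HostVector() {
        if (!synced_) {             // unsynchronised check-then-act:
          host_.assign(1024, 1.f);  // racy if first called from many threads
          synced_ = true;
        }
        return host_;
      }

     private:
      bool synced_ {false};
      std::vector<float> host_;
    };

    int main() {
      LazyHost data;
      data.HostVector();  // the fix: materialise once, on one thread
    #pragma omp parallel for schedule(static)
      for (int i = 0; i < 1024; ++i) {
        float v = data.HostVector()[i];  // now a pure, safe read
        (void)v;
      }
      std::printf("ok\n");
      return 0;
    }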
@ -225,12 +225,12 @@ class GPUPredictor : public xgboost::Predictor {
                      HostDeviceVector<bst_float>* predictions,
                      size_t batch_offset) {
     dh::safe_cuda(cudaSetDevice(device_));
-    const int BLOCK_THREADS = 128;
+    const uint32_t BLOCK_THREADS = 128;
     size_t num_rows = batch.Size();
-    const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
+    auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));

-    int shared_memory_bytes = static_cast<int>
-        (sizeof(float) * num_features * BLOCK_THREADS);
+    auto shared_memory_bytes =
+        static_cast<size_t>(sizeof(float) * num_features * BLOCK_THREADS);
     bool use_shared = true;
     if (shared_memory_bytes > max_shared_memory_bytes_) {
       shared_memory_bytes = 0;
@ -238,11 +238,12 @@ class GPUPredictor : public xgboost::Predictor {
     }
     size_t entry_start = 0;

-    PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
-        (dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
-         dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
-         batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
-         entry_start, use_shared, this->num_group_);
+    dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
+        PredictKernel<BLOCK_THREADS>,
+        dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
+        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
+        batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
+        entry_start, use_shared, this->num_group_);
   }

   void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
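
Beyond the type tightening, this hunk keeps the predictor's shared-memory fallback intact: the per-block buffer is sized from the feature count, and when it exceeds the device limit the kernel is launched with 0 bytes and reads from global memory instead. A small standalone sketch of just that decision (the feature count is made up; the limit is queried from device 0):

    // Sketch of the shared-memory-or-global fallback decision.
    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
      cudaDeviceProp prop;
      if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) { return 1; }

      size_t const num_features = 4096;   // illustrative value
      size_t const block_threads = 128;
      size_t shared_memory_bytes = sizeof(float) * num_features * block_threads;

      bool use_shared = true;
      if (shared_memory_bytes > prop.sharedMemPerBlock) {
        shared_memory_bytes = 0;  // kernel reads features from global memory
        use_shared = false;
      }
      std::printf("use_shared=%d bytes=%zu limit=%zu\n",
                  use_shared, shared_memory_bytes, prop.sharedMemPerBlock);
      return 0;
    }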
@ -165,10 +165,11 @@ __global__ void ClearBuffersKernel(
 void FeatureInteractionConstraint::ClearBuffers() {
   CHECK_EQ(output_buffer_bits_.Size(), input_buffer_bits_.Size());
   CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size());
-  int constexpr kBlockThreads = 256;
-  const int n_grids = static_cast<int>(
+  uint32_t constexpr kBlockThreads = 256;
+  auto const n_grids = static_cast<uint32_t>(
       common::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
-  ClearBuffersKernel<<<n_grids, kBlockThreads>>>(
+  dh::LaunchKernel {n_grids, kBlockThreads} (
+      ClearBuffersKernel,
       output_buffer_bits_, input_buffer_bits_);
 }

@ -222,12 +223,14 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
   LBitField64 node_constraints = s_node_constraints_[nid];
   CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size());

-  int constexpr kBlockThreads = 256;
-  const int n_grids = static_cast<int>(
+  uint32_t constexpr kBlockThreads = 256;
+  auto n_grids = static_cast<uint32_t>(
       common::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
-  SetInputBufferKernel<<<n_grids, kBlockThreads>>>(feature_list, input_buffer_bits_);
-  QueryFeatureListKernel<<<n_grids, kBlockThreads>>>(
+  dh::LaunchKernel {n_grids, kBlockThreads} (
+      SetInputBufferKernel,
+      feature_list, input_buffer_bits_);
+  dh::LaunchKernel {n_grids, kBlockThreads} (
+      QueryFeatureListKernel,
       node_constraints, input_buffer_bits_, output_buffer_bits_);

   thrust::counting_iterator<int32_t> begin(0);
@ -327,20 +330,20 @@ void FeatureInteractionConstraint::Split(
   dim3 const block3(16, 64, 1);
   dim3 const grid3(common::DivRoundUp(n_sets_, 16),
                    common::DivRoundUp(s_fconstraints_.size(), 64));
-  RestoreFeatureListFromSetsKernel<<<grid3, block3>>>
-      (feature_buffer_,
-       feature_id,
-       s_fconstraints_,
-       s_fconstraints_ptr_,
-       s_sets_,
-       s_sets_ptr_);
+  dh::LaunchKernel {grid3, block3} (
+      RestoreFeatureListFromSetsKernel,
+      feature_buffer_, feature_id,
+      s_fconstraints_, s_fconstraints_ptr_,
+      s_sets_, s_sets_ptr_);

-  int constexpr kBlockThreads = 256;
-  const int n_grids = static_cast<int>(common::DivRoundUp(node.Size(), kBlockThreads));
-  InteractionConstraintSplitKernel<<<n_grids, kBlockThreads>>>
-      (feature_buffer_,
-       feature_id,
-       node, left, right);
+  uint32_t constexpr kBlockThreads = 256;
+  auto n_grids = static_cast<uint32_t>(common::DivRoundUp(node.Size(), kBlockThreads));
+  dh::LaunchKernel {n_grids, kBlockThreads} (
+      InteractionConstraintSplitKernel,
+      feature_buffer_,
+      feature_id,
+      node, left, right);
 }

 }  // namespace xgboost
@ -603,12 +603,12 @@ struct GPUHistMakerDevice {
   }

   // One block for each feature
-  int constexpr kBlockThreads = 256;
-  EvaluateSplitKernel<kBlockThreads, GradientSumT>
-      <<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
+  uint32_t constexpr kBlockThreads = 256;
+  dh::LaunchKernel {uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]} (
+      EvaluateSplitKernel<kBlockThreads, GradientSumT>,
       hist.GetNodeHistogram(nidx), d_feature_set, node, page->matrix,
       gpu_param, d_split_candidates, node_value_constraints[nidx],
       monotone_constraints);

   // Reduce over features to find best feature
   auto d_cub_memory =
@ -638,14 +638,12 @@ struct GPUHistMakerDevice {
         use_shared_memory_histograms
         ? sizeof(GradientSumT) * page->matrix.BinCount()
         : 0;
-    const int items_per_thread = 8;
-    const int block_threads = 256;
-    const int grid_size = static_cast<int>(
+    uint32_t items_per_thread = 8;
+    uint32_t block_threads = 256;
+    auto grid_size = static_cast<uint32_t>(
         common::DivRoundUp(n_elements, items_per_thread * block_threads));
-    if (grid_size <= 0) {
-      return;
-    }
-    SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
+    dh::LaunchKernel {grid_size, block_threads, smem_size} (
+        SharedMemHistKernel<GradientSumT>,
         page->matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
         use_shared_memory_histograms);
   }
@ -886,6 +884,7 @@ struct GPUHistMakerDevice {
     monitor.StartCuda("InitRoot");
     this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_);
     monitor.StopCuda("InitRoot");
+
     auto timestamp = qexpand->size();
     auto num_leaves = 1;

@ -895,7 +894,6 @@ struct GPUHistMakerDevice {
       if (!candidate.IsValid(param, num_leaves)) {
        continue;
       }
-
       this->ApplySplit(candidate, p_tree);

       num_leaves++;
@ -996,18 +994,22 @@ class GPUHistMakerSpecialised {
     try {
       for (xgboost::RegTree* tree : trees) {
         this->UpdateTree(gpair, dmat, tree);
+
+        if (hist_maker_param_.debug_synchronize) {
+          this->CheckTreesSynchronized(tree);
+        }
       }
       dh::safe_cuda(cudaGetLastError());
     } catch (const std::exception& e) {
       LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
     }

     param_.learning_rate = lr;
     monitor_.StopCuda("Update");
   }

   void InitDataOnce(DMatrix* dmat) {
     info_ = &dmat->Info();

     reducer_.Init({device_});

     // Synchronise the column sampling seed
@ -1048,20 +1050,18 @@ class GPUHistMakerSpecialised {
   }

   // Only call this method for testing
-  void CheckTreesSynchronized(const std::vector<RegTree>& local_trees) const {
+  void CheckTreesSynchronized(RegTree* local_tree) const {
     std::string s_model;
     common::MemoryBufferStream fs(&s_model);
     int rank = rabit::GetRank();
     if (rank == 0) {
-      local_trees.front().SaveModel(&fs);
+      local_tree->SaveModel(&fs);
     }
     fs.Seek(0);
     rabit::Broadcast(&s_model, 0);
-    RegTree reference_tree{};
+    RegTree reference_tree {};  // rank 0 tree
     reference_tree.LoadModel(&fs);
-    for (const auto& tree : local_trees) {
-      CHECK(tree == reference_tree);
-    }
+    CHECK(*local_tree == reference_tree);
   }

   void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
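
`CheckTreesSynchronized`, enabled through the new `debug_synchronize` parameter and exercised by the restored GPU Hist synchronization test, is a broadcast-and-compare: rank 0 serialises its tree, every worker receives the copy and asserts its own tree is identical. The same pattern in a minimal, self-contained form, with a plain string standing in for the serialised `RegTree`:

    // Broadcast-and-compare sketch; the "model" is just a string here.
    #include <rabit/rabit.h>
    #include <cassert>
    #include <string>

    int main(int argc, char* argv[]) {
      rabit::Init(argc, argv);

      std::string local = "serialised-tree";  // stand-in for SaveModel output
      std::string reference;
      if (rabit::GetRank() == 0) {
        reference = local;              // rank 0 provides the reference copy
      }
      rabit::Broadcast(&reference, 0);  // identical on all workers afterwards
      assert(local == reference);       // trips if any worker diverged

      rabit::Finalize();
      return 0;
    }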
@ -18,7 +18,7 @@ ENV PATH=/opt/python/bin:$PATH
 # Create new Conda environment with cuDF and dask
 RUN \
     conda create -n cudf_test -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda \
-        cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask
+        cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask dask-cuda

 # Install other Python packages
 RUN \

@ -17,7 +17,8 @@ ENV PATH=/opt/python/bin:$PATH
 # Install Python packages
 RUN \
     pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \
-    pip install "dask[complete]"
+    pip install "dask[complete]" && \
+    conda install -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda dask-cuda

 ENV GOSU_VERSION 1.10
@ -21,18 +21,12 @@ RUN \
 # NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html)
 RUN \
     export CUDA_SHORT=`echo $CUDA_VERSION | egrep -o '[0-9]+\.[0-9]'` && \
-    if [ "${CUDA_SHORT}" != "10.0" ] && [ "${CUDA_SHORT}" != "10.1" ]; then \
-    wget https://developer.download.nvidia.com/compute/redist/nccl/v2.2/nccl_2.2.13-1%2Bcuda${CUDA_SHORT}_x86_64.txz && \
-    tar xf "nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64.txz" && \
-    cp nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64/include/nccl.h /usr/include && \
-    cp nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64/lib/* /usr/lib && \
-    rm -f nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64.txz && \
-    rm -r nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64; else \
+    export NCCL_VERSION=2.4.8-1 && \
     wget https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
     rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
     yum -y update && \
-    yum install -y libnccl-2.4.2-1+cuda${CUDA_SHORT} libnccl-devel-2.4.2-1+cuda${CUDA_SHORT} libnccl-static-2.4.2-1+cuda${CUDA_SHORT} && \
-    rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm; fi
+    yum install -y libnccl-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-devel-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-static-${NCCL_VERSION}+cuda${CUDA_SHORT} && \
+    rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm;

 ENV PATH=/opt/python/bin:$PATH
 ENV CC=/opt/rh/devtoolset-4/root/usr/bin/gcc
@ -33,11 +33,13 @@ case "$suite" in
         pytest -v -s --fulltrace -m "(not slow) and mgpu" tests/python-gpu
         cd tests/distributed
         ./runtests-gpu.sh
+        cd -
+        pytest -v -s --fulltrace -m "mgpu" tests/python-gpu/test_gpu_with_dask.py
         ;;

     cudf)
         source activate cudf_test
-        python -m pytest -v -s --fulltrace tests/python-gpu/test_from_columnar.py tests/python-gpu/test_gpu_with_dask.py
+        pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py
         ;;

     cpu)
@ -19,7 +19,7 @@ if (USE_CUDA)
   # OpenMP is mandatory for CUDA
   find_package(OpenMP REQUIRED)
   target_include_directories(testxgboost PRIVATE
-    ${PROJECT_SOURCE_DIR}/cub/)
+    ${xgboost_SOURCE_DIR}/cub/)
   target_compile_options(testxgboost PRIVATE
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
@ -48,9 +48,9 @@ endif (USE_CUDA)
 target_include_directories(testxgboost
   PRIVATE
   ${GTEST_INCLUDE_DIRS}
-  ${PROJECT_SOURCE_DIR}/include
-  ${PROJECT_SOURCE_DIR}/dmlc-core/include
-  ${PROJECT_SOURCE_DIR}/rabit/include)
+  ${xgboost_SOURCE_DIR}/include
+  ${xgboost_SOURCE_DIR}/dmlc-core/include
+  ${xgboost_SOURCE_DIR}/rabit/include)
 set_target_properties(
   testxgboost PROPERTIES
   CXX_STANDARD 11
@ -67,7 +67,7 @@ target_compile_definitions(testxgboost PRIVATE ${XGBOOST_DEFINITIONS})
 if (USE_OPENMP)
   target_compile_options(testxgboost PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${OpenMP_CXX_FLAGS}>)
 endif (USE_OPENMP)
-set_output_directory(testxgboost ${PROJECT_BINARY_DIR})
+set_output_directory(testxgboost ${xgboost_BINARY_DIR})

 # This grouping organises source files nicely in visual studio
 auto_source_group("${TEST_SOURCES}")
@ -2,6 +2,7 @@
 import sys
 import time
 import xgboost as xgb
+import os


 def run_test(name, params_fun):
@ -48,6 +49,9 @@ def run_test(name, params_fun):

     xgb.rabit.finalize()

+    if os.path.exists(model_name):
+        os.remove(model_name)
+

 base_params = {
     'tree_method': 'gpu_hist',
@ -81,7 +85,5 @@ def wrap_rf(params_fun):

 params_rf_1x4 = wrap_rf(params_basic_1x4)

-
-
 test_name = sys.argv[1]
 run_test(test_name, globals()['params_%s' % test_name])
@ -6,7 +6,7 @@ export DMLC_SUBMIT_CLUSTER=local
 submit="timeout 30 python ../../dmlc-core/tracker/dmlc-submit"

 echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
-$submit --num-workers=4 python distributed_gpu.py basic_1x4 || exit 1
+$submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py basic_1x4 || exit 1

 echo -e "\n ====== 2. RF distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
-$submit --num-workers=4 python distributed_gpu.py rf_1x4 || exit 1
+$submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py rf_1x4 || exit 1
tests/pytest.ini (new file)
@ -0,0 +1,3 @@
+[pytest]
+markers =
+    mgpu: Mark a test that requires multiple GPUs to run.
@ -2,6 +2,7 @@ import numpy as np
 import sys
 import unittest
 import pytest
+import xgboost

 sys.path.append("tests/python")
 from regression_test_utilities import run_suite, parameter_combinations, \
@ -21,7 +22,8 @@ datasets = ["Boston", "Cancer", "Digits", "Sparse regression",

 class TestGPU(unittest.TestCase):
     def test_gpu_hist(self):
-        test_param = parameter_combinations({'gpu_id': [0], 'max_depth': [2, 8],
+        test_param = parameter_combinations({'gpu_id': [0],
+                                             'max_depth': [2, 8],
                                              'max_leaves': [255, 4],
                                              'max_bin': [2, 256],
                                              'grow_policy': ['lossguide']})
@ -36,6 +38,31 @@ class TestGPU(unittest.TestCase):
             cpu_results = run_suite(param, select_datasets=datasets)
             assert_gpu_results(cpu_results, gpu_results)

+    def test_with_empty_dmatrix(self):
+        # FIXME(trivialfis): This should be done with all updaters
+        kRows = 0
+        kCols = 100
+
+        X = np.empty((kRows, kCols))
+        y = np.empty((kRows))
+
+        dtrain = xgboost.DMatrix(X, y)
+
+        bst = xgboost.train({'verbosity': 2,
+                             'tree_method': 'gpu_hist',
+                             'gpu_id': 0},
+                            dtrain,
+                            verbose_eval=True,
+                            num_boost_round=6,
+                            evals=[(dtrain, 'Train')])
+
+        kRows = 100
+        X = np.random.randn(kRows, kCols)
+
+        dtest = xgboost.DMatrix(X)
+        predictions = bst.predict(dtest)
+        np.testing.assert_allclose(predictions, 0.5, 1e-6)
+
     @pytest.mark.mgpu
     def test_specified_gpu_id_gpu_update(self):
         variable_param = {'gpu_id': [1],
@ -1,45 +1,94 @@
 import sys
 import pytest
+import numpy as np
+import unittest

 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)

 try:
-    from distributed.utils_test import client, loop, cluster_fixture
     import dask.dataframe as dd
     from xgboost import dask as dxgb
+    from dask_cuda import LocalCUDACluster
+    from dask.distributed import Client
     import cudf
 except ImportError:
-    client = None
-    loop = None
-    cluster_fixture = None
     pass

 sys.path.append("tests/python")
-from test_with_dask import generate_array
-import testing as tm
+from test_with_dask import generate_array  # noqa
+import testing as tm  # noqa


-@pytest.mark.skipif(**tm.no_dask())
-@pytest.mark.skipif(**tm.no_cudf())
-@pytest.mark.skipif(**tm.no_dask_cudf())
-def test_dask_dataframe(client):
-    X, y = generate_array()
+class TestDistributedGPU(unittest.TestCase):
+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.skipif(**tm.no_cudf())
+    @pytest.mark.skipif(**tm.no_dask_cudf())
+    @pytest.mark.skipif(**tm.no_dask_cuda())
+    def test_dask_dataframe(self):
+        with LocalCUDACluster() as cluster:
+            with Client(cluster) as client:
+                X, y = generate_array()

                 X = dd.from_dask_array(X)
                 y = dd.from_dask_array(y)

                 X = X.map_partitions(cudf.from_pandas)
                 y = y.map_partitions(cudf.from_pandas)

                 dtrain = dxgb.DaskDMatrix(client, X, y)
                 out = dxgb.train(client, {'tree_method': 'gpu_hist'},
                                  dtrain=dtrain,
                                  evals=[(dtrain, 'X')],
                                  num_boost_round=2)

                 assert isinstance(out['booster'], dxgb.Booster)
                 assert len(out['history']['X']['rmse']) == 2

-    predictions = dxgb.predict(out, dtrain)
-    predictions = predictions.compute()
+                # FIXME(trivialfis): Re-enable this after #5003 is fixed
+                # predictions = dxgb.predict(client, out, dtrain).compute()
+                # assert isinstance(predictions, np.ndarray)

+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.skipif(**tm.no_dask_cuda())
+    @pytest.mark.mgpu
+    def test_empty_dmatrix(self):
+
+        def _check_outputs(out, predictions):
+            assert isinstance(out['booster'], dxgb.Booster)
+            assert len(out['history']['validation']['rmse']) == 2
+            assert isinstance(predictions, np.ndarray)
+            assert predictions.shape[0] == 1
+
+        parameters = {'tree_method': 'gpu_hist', 'verbosity': 3,
+                      'debug_synchronize': True}
+
+        with LocalCUDACluster() as cluster:
+            with Client(cluster) as client:
+                kRows, kCols = 1, 97
+                X = dd.from_array(np.random.randn(kRows, kCols))
+                y = dd.from_array(np.random.rand(kRows))
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+
+                out = dxgb.train(client, parameters,
+                                 dtrain=dtrain,
+                                 evals=[(dtrain, 'validation')],
+                                 num_boost_round=2)
+                predictions = dxgb.predict(client=client, model=out,
+                                           data=dtrain).compute()
+                _check_outputs(out, predictions)
+
+                # train has more rows than evals
+                valid = dtrain
+                kRows += 1
+                X = dd.from_array(np.random.randn(kRows, kCols))
+                y = dd.from_array(np.random.rand(kRows))
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+
+                out = dxgb.train(client, parameters,
+                                 dtrain=dtrain,
+                                 evals=[(valid, 'validation')],
+                                 num_boost_round=2)
+                predictions = dxgb.predict(client=client, model=out,
+                                           data=valid).compute()
+                _check_outputs(out, predictions)
@ -67,7 +67,8 @@ def get_weights_regression(min_weight, max_weight):
     n = 10000
     sparsity = 0.25
     X, y = datasets.make_regression(n, random_state=rng)
-    X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X])
+    X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x
+                   for x in x_row] for x_row in X])
     w = np.array([rng.uniform(min_weight, max_weight) for i in range(n)])
     return X, y, w

@ -34,6 +34,15 @@ def no_matplotlib():
             'reason': reason}


+def no_dask_cuda():
+    reason = 'dask_cuda is not installed.'
+    try:
+        import dask_cuda as _  # noqa
+        return {'condition': False, 'reason': reason}
+    except ImportError:
+        return {'condition': True, 'reason': reason}
+
+
 def no_cudf():
     return {'condition': not CUDF_INSTALLED,
             'reason': 'CUDF is not installed'}
@ -34,6 +34,13 @@ fi

 if [ ${TASK} == "cmake_test" ]; then
     set -e
+
+    if grep -n -R '<<<.*>>>\(.*\)' src include | grep --invert "NOLINT"; then
+        echo 'Do not use raw CUDA execution configuration syntax with <<<blocks, threads>>>.' \
+             'try `dh::LaunchKernel`'
+        exit -1
+    fi
+
     # Build/test
     rm -rf build
     mkdir build && cd build