Run training with empty DMatrix. (#4990)

This makes GPU Hist robust in distributed environments, where some workers
might not be associated with any data during either training or evaluation.

* Disable the rabit mock test for now: see #5012.

* Disable the dask-cudf prediction test for now: see #5003.

* Launch the dask job on all workers, even those that might not have any data.
* Check for 0 rows in elementwise evaluation metrics.

   Using AUC and AUC-PR still throws an error.  See #4663 for a robust fix.
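
   A minimal sketch of the zero-row guard applied to each metric's `GetFinal`
   (mirroring the elementwise_metric.cu hunk below); `bst_float` is aliased
   here only to keep the fragment self-contained:

       using bst_float = float;  // xgboost's single-precision scalar type

       // On an empty worker, wsum is 0 after the allreduce; dividing would
       // turn the aggregated metric into NaN, so return the residue as-is.
       static bst_float GetFinal(bst_float esum, bst_float wsum) {
         return wsum == 0 ? esum : esum / wsum;
       }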

* Add tests for edge cases.
* Add a `LaunchKernel` wrapper that handles zero-sized grids (usage sketch after this list).
* Move some parts of the allreducer into a .cu file.
* Don't validate feature names when the booster is empty.
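
  A usage sketch for the wrapper; `ScaleKernel` is illustrative and not part
  of the change, while `dh::LaunchKernel` and `common::DivRoundUp` come from
  the device_helpers.cuh hunk below:

      #include "device_helpers.cuh"  // provides dh::LaunchKernel

      // Illustrative kernel: scale n values in place.
      __global__ void ScaleKernel(float* data, size_t n, float factor) {
        size_t i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) { data[i] *= factor; }
      }

      void Scale(float* data, size_t n, float factor) {
        uint32_t constexpr kBlockThreads = 256;
        // On an empty worker n is 0, so kGrids is 0 and the wrapper skips
        // the launch instead of issuing an invalid <<<0, 256>>> config.
        uint32_t const kGrids = xgboost::common::DivRoundUp(n, kBlockThreads);
        dh::LaunchKernel {kGrids, kBlockThreads} (ScaleKernel, data, n, factor);
      }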

* Sync the number of columns in DMatrix.

  num_feature is required to be the same across all workers in data split
  mode.
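
  A sketch of the sync, following the data.cc hunk below: every rank writes
  its local column count into its own slot of a zero-filled buffer, a
  sum-allreduce combines the slots, and empty workers adopt the maximum.
  The helper name here is hypothetical:

      #include <rabit/rabit.h>

      #include <algorithm>
      #include <cstdint>
      #include <vector>

      // Hypothetical helper condensing the column sync in DMatrix::Create.
      // Each rank owns a disjoint slot, so a sum acts as an allgather.
      uint64_t SyncNumColumns(uint64_t local_cols) {
        std::vector<uint64_t> ncols(rabit::GetWorldSize(), 0);
        ncols[rabit::GetRank()] = local_cols;
        rabit::Allreduce<rabit::op::Sum>(ncols.data(), ncols.size());
        return *std::max_element(ncols.cbegin(), ncols.cend());
      }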

* Result filtering in the dask interface now by default returns the booster
  from any worker with non-empty data, instead of always using rank 0.

* Fix Jenkins' GPU tests.

* Install dask-cuda from source in Jenkins' tests.

  Now all tests are actually running.

* Restore GPU Hist tree synchronization test.

* Check the UUIDs of running devices.

  The check is only performed on CUDA >= 10.x, as 9.x doesn't have the UUID field.
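
  A condensed sketch of the uniqueness check; the real implementation (in
  the new allreducer .cu file below) reads the UUID from cudaDeviceProp,
  while this hypothetical helper takes the UUID words as input and, unlike
  the committed code, sorts before looking for duplicates:

      #include <rabit/rabit.h>

      #include <algorithm>
      #include <cstddef>
      #include <cstdint>
      #include <vector>

      // Each rank writes its UUID words into its own slice of a
      // zero-initialised buffer; a sum-allreduce over disjoint slices
      // acts as an allgather.
      bool AllDevicesDistinct(std::vector<uint64_t> const& my_uuid) {
        std::size_t const len = my_uuid.size();
        std::vector<uint64_t> all(rabit::GetWorldSize() * len, 0);
        std::copy(my_uuid.cbegin(), my_uuid.cend(),
                  all.begin() + rabit::GetRank() * len);
        rabit::Allreduce<rabit::op::Sum>(all.data(), all.size());

        std::vector<std::vector<uint64_t>> uuids;
        for (std::size_t i = 0; i < all.size(); i += len) {
          uuids.emplace_back(all.cbegin() + i, all.cbegin() + i + len);
        }
        std::sort(uuids.begin(), uuids.end());
        return std::adjacent_find(uuids.begin(), uuids.end()) == uuids.end();
      }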

* Fix CMake policy and project variables.

  Use xgboost_SOURCE_DIR uniformly, and add a policy setting for CMake >= 3.13.

* Fix copying data to CPU.

* Fix a race condition in the CPU predictor.
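
  The fix (visible in the cpu_predictor.cc hunk below) is to materialise
  lazy host copies before entering the parallel region.  A minimal,
  self-contained illustration of the pattern, not xgboost's actual
  HostDeviceVector:

      #include <omp.h>

      #include <vector>

      // A lazily materialised buffer: the first HostVector() call mutates
      // shared state, so calling it concurrently would be a data race.
      struct LazyBuffer {
        std::vector<float> host;
        std::vector<float> const& HostVector() {
          if (host.empty()) { host.resize(1024, 1.0f); }  // lazy init
          return host;
        }
      };

      float Sum(LazyBuffer* buf) {
        buf->HostVector();  // pull to host once, before the omp block
        float total = 0;
      #pragma omp parallel for schedule(static) reduction(+ : total)
        for (int i = 0; i < 1024; ++i) {
          total += buf->HostVector()[i];  // safe now: read-only access
        }
        return total;
      }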

* Fix duplicated DMatrix construction.

* Don't download an extra NCCL copy in the CI script.
Jiaming Yuan, 2019-11-06 16:13:13 +08:00, committed by GitHub
parent 807a244517
commit 7663de956c
44 changed files with 603 additions and 272 deletions


@@ -1,9 +1,13 @@
 cmake_minimum_required(VERSION 3.3)
 project(xgboost LANGUAGES CXX C VERSION 1.0.0)
 include(cmake/Utils.cmake)
-list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules")
+list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules")
 cmake_policy(SET CMP0022 NEW)
+if ((${CMAKE_VERSION} VERSION_GREATER 3.13) OR (${CMAKE_VERSION} VERSION_EQUAL 3.13))
+  cmake_policy(SET CMP0077 NEW)
+endif ((${CMAKE_VERSION} VERSION_GREATER 3.13) OR (${CMAKE_VERSION} VERSION_EQUAL 3.13))
 message(STATUS "CMake version ${CMAKE_VERSION}")
 if (MSVC)
   cmake_minimum_required(VERSION 3.11)
@@ -84,7 +88,7 @@ endif (USE_CUDA)
 # dmlc-core
 msvc_use_static_runtime()
-add_subdirectory(${PROJECT_SOURCE_DIR}/dmlc-core)
+add_subdirectory(${xgboost_SOURCE_DIR}/dmlc-core)
 set_target_properties(dmlc PROPERTIES
   CXX_STANDARD 11
   CXX_STANDARD_REQUIRED ON
@@ -105,7 +109,7 @@ endif(RABIT_MOCK)
 # Exports some R specific definitions and objects
 if (R_LIB)
-  add_subdirectory(${PROJECT_SOURCE_DIR}/R-package)
+  add_subdirectory(${xgboost_SOURCE_DIR}/R-package)
 endif (R_LIB)
 # core xgboost
@@ -123,22 +127,23 @@ target_link_libraries(xgboost PRIVATE ${LINKED_LIBRARIES_PRIVATE})
 # This creates its own shared library `xgboost4j'.
 if (JVM_BINDINGS)
-  add_subdirectory(${PROJECT_SOURCE_DIR}/jvm-packages)
+  add_subdirectory(${xgboost_SOURCE_DIR}/jvm-packages)
 endif (JVM_BINDINGS)
 #-- End shared library
 #-- CLI for xgboost
-add_executable(runxgboost ${PROJECT_SOURCE_DIR}/src/cli_main.cc ${XGBOOST_OBJ_SOURCES})
+add_executable(runxgboost ${xgboost_SOURCE_DIR}/src/cli_main.cc ${XGBOOST_OBJ_SOURCES})
 # For cli_main.cc only
 if (USE_OPENMP)
   find_package(OpenMP REQUIRED)
   target_compile_options(runxgboost PRIVATE ${OpenMP_CXX_FLAGS})
 endif (USE_OPENMP)
 target_include_directories(runxgboost
   PRIVATE
-    ${PROJECT_SOURCE_DIR}/include
-    ${PROJECT_SOURCE_DIR}/dmlc-core/include
-    ${PROJECT_SOURCE_DIR}/rabit/include)
+    ${xgboost_SOURCE_DIR}/include
+    ${xgboost_SOURCE_DIR}/dmlc-core/include
+    ${xgboost_SOURCE_DIR}/rabit/include)
 target_link_libraries(runxgboost PRIVATE ${LINKED_LIBRARIES_PRIVATE})
 set_target_properties(
   runxgboost PROPERTIES
@@ -147,8 +152,8 @@ set_target_properties(
   CXX_STANDARD_REQUIRED ON)
 #-- End CLI for xgboost
-set_output_directory(runxgboost ${PROJECT_SOURCE_DIR})
-set_output_directory(xgboost ${PROJECT_SOURCE_DIR}/lib)
+set_output_directory(runxgboost ${xgboost_SOURCE_DIR})
+set_output_directory(xgboost ${xgboost_SOURCE_DIR}/lib)
 # Ensure these two targets do not build simultaneously, as they produce outputs with conflicting names
 add_dependencies(xgboost runxgboost)
@@ -205,21 +210,21 @@ install(
 if (GOOGLE_TEST)
   enable_testing()
   # Unittests.
-  add_subdirectory(${PROJECT_SOURCE_DIR}/tests/cpp)
+  add_subdirectory(${xgboost_SOURCE_DIR}/tests/cpp)
   add_test(
     NAME TestXGBoostLib
     COMMAND testxgboost
-    WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+    WORKING_DIRECTORY ${xgboost_BINARY_DIR})
   # CLI tests
   configure_file(
-    ${PROJECT_SOURCE_DIR}/tests/cli/machine.conf.in
-    ${PROJECT_BINARY_DIR}/tests/cli/machine.conf
+    ${xgboost_SOURCE_DIR}/tests/cli/machine.conf.in
+    ${xgboost_BINARY_DIR}/tests/cli/machine.conf
     @ONLY)
   add_test(
     NAME TestXGBoostCLI
-    COMMAND runxgboost ${PROJECT_BINARY_DIR}/tests/cli/machine.conf
-    WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+    COMMAND runxgboost ${xgboost_BINARY_DIR}/tests/cli/machine.conf
+    WORKING_DIRECTORY ${xgboost_BINARY_DIR})
   set_tests_properties(TestXGBoostCLI
     PROPERTIES
     PASS_REGULAR_EXPRESSION ".*test-rmse:0.087.*")

Jenkinsfile

@@ -83,7 +83,6 @@ pipeline {
       'test-python-gpu-cuda10.0': { TestPythonGPU(cuda_version: '10.0') },
       'test-python-gpu-cuda10.1': { TestPythonGPU(cuda_version: '10.1') },
       'test-python-mgpu-cuda10.1': { TestPythonGPU(cuda_version: '10.1', multi_gpu: true) },
-      'test-cpp-rabit': {TestCppRabit()},
       'test-cpp-gpu': { TestCppGPU(cuda_version: '10.1') },
       'test-cpp-mgpu': { TestCppGPU(cuda_version: '10.1', multi_gpu: true) },
       'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '2.4.3') },


@@ -6,7 +6,7 @@ function (run_doxygen)
   endif (NOT DOXYGEN_DOT_FOUND)
   configure_file(
-    ${PROJECT_SOURCE_DIR}/doc/Doxyfile.in
+    ${xgboost_SOURCE_DIR}/doc/Doxyfile.in
     ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile @ONLY)
   add_custom_target( doc_doxygen ALL
     COMMAND ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile


@@ -111,7 +111,7 @@ DESTINATION \"${build_dir}/bak\")")
 install(CODE "file(REMOVE_RECURSE \"${build_dir}/R-package\")")
 install(
-  DIRECTORY "${PROJECT_SOURCE_DIR}/R-package"
+  DIRECTORY "${xgboost_SOURCE_DIR}/R-package"
   DESTINATION "${build_dir}"
   REGEX "src/*" EXCLUDE
   REGEX "R-package/configure" EXCLUDE


@@ -5,6 +5,5 @@ function (write_version)
     ${xgboost_SOURCE_DIR}/include/xgboost/version_config.h @ONLY)
   configure_file(
     ${xgboost_SOURCE_DIR}/cmake/Python_version.in
-    ${xgboost_SOURCE_DIR}/python-package/xgboost/VERSION
-  )
+    ${xgboost_SOURCE_DIR}/python-package/xgboost/VERSION)
 endfunction (write_version)


@@ -0,0 +1,23 @@
+if (NVML_LIBRARY)
+  unset(NVML_LIBRARY CACHE)
+endif(NVML_LIBRARY)
+
+set(NVML_LIB_NAME nvml)
+
+find_path(NVML_INCLUDE_DIR
+  NAMES nvml.h
+  PATHS ${CUDA_HOME}/include ${CUDA_INCLUDE} /usr/local/cuda/include)
+
+find_library(NVML_LIBRARY
+  NAMES nvidia-ml)
+
+message(STATUS "Using nvml library: ${NVML_LIBRARY}")
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NVML DEFAULT_MSG
+  NVML_INCLUDE_DIR NVML_LIBRARY)
+
+mark_as_advanced(
+  NVML_INCLUDE_DIR
+  NVML_LIBRARY
+)


@@ -513,7 +513,7 @@ class DMatrix(object):
             try:
                 csr = scipy.sparse.csr_matrix(data)
                 self._init_from_csr(csr)
-            except:
+            except Exception:
                 raise TypeError('can not initialize DMatrix from'
                                 ' {}'.format(type(data).__name__))
@@ -577,9 +577,9 @@ class DMatrix(object):
         if len(mat.shape) != 2:
             raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
                              mat.shape)
-        # flatten the array by rows and ensure it is float32.
-        # we try to avoid data copies if possible (reshape returns a view when possible
-        # and we explicitly tell np.array to try and avoid copying)
+        # flatten the array by rows and ensure it is float32. we try to avoid
+        # data copies if possible (reshape returns a view when possible and we
+        # explicitly tell np.array to try and avoid copying)
         data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
         handle = ctypes.c_void_p()
         missing = missing if missing is not None else np.nan
@@ -1391,8 +1391,9 @@ class Booster(object):
            value of the prediction. Note the last row and column correspond to the bias term.
         validate_features : bool
-            When this is True, validate that the Booster's and data's feature_names are identical.
-            Otherwise, it is assumed that the feature_names are the same.
+            When this is True, validate that the Booster's and data's
+            feature_names are identical.  Otherwise, it is assumed that the
+            feature_names are the same.

         Returns
         -------
@@ -1811,8 +1812,8 @@ class Booster(object):
                 msg = 'feature_names mismatch: {0} {1}'
                 if dat_missing:
-                    msg += ('\nexpected ' + ', '.join(str(s) for s in dat_missing) +
-                            ' in input data')
+                    msg += ('\nexpected ' + ', '.join(
+                        str(s) for s in dat_missing) + ' in input data')
                 if my_missing:
                     msg += ('\ntraining data did not have the following fields: ' +
@@ -1821,7 +1822,8 @@ class Booster(object):
             raise ValueError(msg.format(self.feature_names,
                                         data.feature_names))

-    def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True):
+    def get_split_value_histogram(self, feature, fmap='', bins=None,
+                                  as_pandas=True):
        """Get split value histogram of a feature

        Parameters


@@ -55,10 +55,14 @@ def _start_tracker(host, n_workers):
     return env


-def _assert_dask_installed():
+def _assert_dask_support():
     if not DASK_INSTALLED:
         raise ImportError(
             'Dask needs to be installed in order to use this module')
+    if platform.system() == 'Windows':
+        msg = 'Windows is not officially supported for dask/xgboost,'
+        msg += ' contribution are welcomed.'
+        logging.warning(msg)


 class RabitContext:
@@ -96,6 +100,11 @@ def _xgb_get_client(client):
     return ret


+def _get_client_workers(client):
+    workers = client.scheduler_info()['workers']
+    return workers
+
+
 class DaskDMatrix:
     # pylint: disable=missing-docstring, too-many-instance-attributes
     '''DMatrix holding on references to Dask DataFrame or Dask Array.
@@ -132,7 +141,7 @@ class DaskDMatrix:
                  weight=None,
                  feature_names=None,
                  feature_types=None):
-        _assert_dask_installed()
+        _assert_dask_support()
         self._feature_names = feature_names
         self._feature_types = feature_types
@@ -263,6 +272,17 @@ class DaskDMatrix:
             A DMatrix object.

         '''
+        if worker.address not in set(self.worker_map.keys()):
+            msg = 'worker {address} has an empty DMatrix. ' \
+                'All workers associated with this DMatrix: {workers}'.format(
+                    address=worker.address,
+                    workers=set(self.worker_map.keys()))
+            logging.warning(msg)
+            d = DMatrix(numpy.empty((0, 0)),
+                        feature_names=self._feature_names,
+                        feature_types=self._feature_types)
+            return d
+
         data, labels, weights = self.get_worker_parts(worker)

         data = concat(data)
@@ -275,7 +295,6 @@ class DaskDMatrix:
             weights = concat(weights)
         else:
             weights = None
-
         dmatrix = DMatrix(data,
                           labels,
                           weight=weights,
@@ -342,35 +361,33 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
                 'eval': {'logloss': ['0.480385', '0.357756']}}}

     '''
-    _assert_dask_installed()
-    if platform.system() == 'Windows':
-        msg = 'Windows is not officially supported for dask/xgboost,'
-        msg += ' contribution are welcomed.'
-        logging.warning(msg)
+    _assert_dask_support()
     if 'evals_result' in kwargs.keys():
         raise ValueError(
             'evals_result is not supported in dask interface.',
             'The evaluation history is returned as result of training.')

     client = _xgb_get_client(client)
-    worker_map = dtrain.worker_map
-    rabit_args = _get_rabit_args(worker_map, client)
+    workers = list(_get_client_workers(client).keys())
+    rabit_args = _get_rabit_args(workers, client)

-    def dispatched_train(worker_id):
-        '''Perform training on worker.'''
-        logging.info('Training on %d', worker_id)
+    def dispatched_train(worker_addr):
+        '''Perform training on a single worker.'''
+        logging.info('Training on %s', str(worker_addr))
         worker = distributed_get_worker()
-        local_dtrain = dtrain.get_worker_data(worker)
-        local_evals = []
-        if evals:
-            for mat, name in evals:
-                local_mat = mat.get_worker_data(worker)
-                local_evals.append((local_mat, name))
         with RabitContext(rabit_args):
+            local_dtrain = dtrain.get_worker_data(worker)
+            local_evals = []
+            if evals:
+                for mat, name in evals:
+                    if mat is dtrain:
+                        local_evals.append((local_dtrain, name))
+                        continue
+                    local_mat = mat.get_worker_data(worker)
+                    local_evals.append((local_mat, name))
             local_history = {}
             local_param = params.copy()  # just to be consistent
             bst = worker_train(params=local_param,
@@ -380,14 +397,14 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
                                evals=local_evals,
                                **kwargs)
             ret = {'booster': bst, 'history': local_history}
-            if rabit.get_rank() != 0:
+            if local_dtrain.num_row() == 0:
                 ret = None
             return ret

     futures = client.map(dispatched_train,
-                         range(len(worker_map)),
+                         workers,
                          pure=False,
-                         workers=list(worker_map.keys()))
+                         workers=workers)
     results = client.gather(futures)
     return list(filter(lambda ret: ret is not None, results))[0]
@@ -414,7 +431,7 @@ def predict(client, model, data, *args):
     prediction: dask.array.Array

     '''
-    _assert_dask_installed()
+    _assert_dask_support()
     if isinstance(model, Booster):
         booster = model
     elif isinstance(model, dict):
@@ -437,7 +454,8 @@ def predict(client, model, data, *args):
         local_x = data.get_worker_data(worker)

         with RabitContext(rabit_args):
-            local_predictions = booster.predict(data=local_x, *args)
+            local_predictions = booster.predict(
+                data=local_x, validate_features=local_x.num_row() != 0, *args)
         return local_predictions

     futures = client.map(dispatched_predict,
@@ -563,7 +581,7 @@ class DaskXGBRegressor(DaskScikitLearnBase):
             sample_weights=None,
             eval_set=None,
             sample_weight_eval_set=None):
-        _assert_dask_installed()
+        _assert_dask_support()
         dtrain = DaskDMatrix(client=self.client,
                              data=X, label=y, weight=sample_weights)
         params = self.get_xgb_params()
@@ -579,7 +597,7 @@ class DaskXGBRegressor(DaskScikitLearnBase):
         return self

     def predict(self, data):  # pylint: disable=arguments-differ
-        _assert_dask_installed()
+        _assert_dask_support()
         test_dmatrix = DaskDMatrix(client=self.client, data=data)
         pred_probs = predict(client=self.client,
                              model=self.get_booster(), data=test_dmatrix)
@@ -599,7 +617,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             sample_weights=None,
             eval_set=None,
             sample_weight_eval_set=None):
-        _assert_dask_installed()
+        _assert_dask_support()
         dtrain = DaskDMatrix(client=self.client,
                              data=X, label=y, weight=sample_weights)
         params = self.get_xgb_params()
@@ -626,7 +644,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         return self

     def predict(self, data):  # pylint: disable=arguments-differ
-        _assert_dask_installed()
+        _assert_dask_support()
         test_dmatrix = DaskDMatrix(client=self.client, data=data)
         pred_probs = predict(client=self.client,
                              model=self.get_booster(), data=test_dmatrix)


@@ -332,7 +332,7 @@ class RabitTracker(object):
         self.thread.start()

     def join(self):
-        while self.thread.isAlive():
+        while self.thread.is_alive():
             self.thread.join(100)

     def alive(self):


@@ -1,5 +1,5 @@
 file(GLOB_RECURSE CPU_SOURCES *.cc *.h)
-list(REMOVE_ITEM CPU_SOURCES ${PROJECT_SOURCE_DIR}/src/cli_main.cc)
+list(REMOVE_ITEM CPU_SOURCES ${xgboost_SOURCE_DIR}/src/cli_main.cc)

 #-- Object library
 # Object library is necessary for jvm-package, which creates its own shared
@@ -9,7 +9,7 @@ if (USE_CUDA)
   add_library(objxgboost OBJECT ${CPU_SOURCES} ${CUDA_SOURCES} ${PLUGINS_SOURCES})
   target_compile_definitions(objxgboost
     PRIVATE -DXGBOOST_USE_CUDA=1)
-  target_include_directories(objxgboost PRIVATE ${PROJECT_SOURCE_DIR}/cub/)
+  target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/cub/)
   target_compile_options(objxgboost PRIVATE
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
@@ -43,9 +43,9 @@ endif (USE_CUDA)
 target_include_directories(objxgboost
   PRIVATE
-    ${PROJECT_SOURCE_DIR}/include
-    ${PROJECT_SOURCE_DIR}/dmlc-core/include
-    ${PROJECT_SOURCE_DIR}/rabit/include)
+    ${xgboost_SOURCE_DIR}/include
+    ${xgboost_SOURCE_DIR}/dmlc-core/include
+    ${xgboost_SOURCE_DIR}/rabit/include)
 target_compile_options(objxgboost
   PRIVATE
     $<$<AND:$<CXX_COMPILER_ID:MSVC>,$<COMPILE_LANGUAGE:CXX>>:/MP>


@@ -0,0 +1,91 @@
+/*!
+ * Copyright 2017-2019 XGBoost contributors
+ *
+ * \brief Utilities for CUDA.
+ */
+#ifdef XGBOOST_USE_NCCL
+#include <nccl.h>
+#endif  // #ifdef XGBOOST_USE_NCCL
+
+#include <sstream>
+
+#include "device_helpers.cuh"
+
+namespace dh {
+
+#if __CUDACC_VER_MAJOR__ > 9
+constexpr std::size_t kUuidLength =
+    sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
+
+void GetCudaUUID(int world_size, int rank, int device_ord,
+                 xgboost::common::Span<uint64_t, kUuidLength> uuid) {
+  cudaDeviceProp prob;
+  safe_cuda(cudaGetDeviceProperties(&prob, device_ord));
+  std::memcpy(uuid.data(), static_cast<void*>(&(prob.uuid)), sizeof(prob.uuid));
+}
+
+std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
+  std::stringstream ss;
+  for (auto v : uuid) {
+    ss << std::hex << v;
+  }
+  return ss.str();
+}
+#endif  // __CUDACC_VER_MAJOR__ > 9
+
+void AllReducer::Init(int _device_ordinal) {
+#ifdef XGBOOST_USE_NCCL
+  LOG(DEBUG) << "Running nccl init on: " << __CUDACC_VER_MAJOR__ << "." << __CUDACC_VER_MINOR__;
+  device_ordinal = _device_ordinal;
+  int32_t const rank = rabit::GetRank();
+
+#if __CUDACC_VER_MAJOR__ > 9
+  int32_t const world = rabit::GetWorldSize();
+  std::vector<uint64_t> uuids(world * kUuidLength, 0);
+  auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
+  auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
+  GetCudaUUID(world, rank, device_ordinal, s_this_uuid);
+
+  // No allgather yet.
+  rabit::Allreduce<rabit::op::Sum, uint64_t>(uuids.data(), uuids.size());
+
+  std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);;
+  size_t j = 0;
+  for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
+    converted[j] =
+        xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
+    j++;
+  }
+
+  auto iter = std::unique(converted.begin(), converted.end());
+  auto n_uniques = std::distance(converted.begin(), iter);
+
+  CHECK_EQ(n_uniques, world)
+      << "Multiple processes within communication group running on same CUDA "
+      << "device is not supported";
+#endif  // __CUDACC_VER_MAJOR__ > 9
+
+  id = GetUniqueId();
+  dh::safe_cuda(cudaSetDevice(device_ordinal));
+  dh::safe_nccl(ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rank));
+  safe_cuda(cudaStreamCreate(&stream));
+  initialised_ = true;
+#endif  // XGBOOST_USE_NCCL
+}
+
+AllReducer::~AllReducer() {
+#ifdef XGBOOST_USE_NCCL
+  if (initialised_) {
+    dh::safe_cuda(cudaStreamDestroy(stream));
+    ncclCommDestroy(comm);
+  }
+  if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
+    LOG(CONSOLE) << "======== NCCL Statistics========";
+    LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
+    LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576;
+  }
+#endif  // XGBOOST_USE_NCCL
+}
+}  // namespace dh


@@ -7,24 +7,25 @@
 #include <thrust/device_malloc_allocator.h>
 #include <thrust/system/cuda/error.h>
 #include <thrust/system_error.h>
-#include <xgboost/logging.h>
-#include <omp.h>
 #include <rabit/rabit.h>
-#include <cub/cub.cuh>
 #include <cub/util_allocator.cuh>
-#include "xgboost/host_device_vector.h"
-#include "xgboost/span.h"
-#include "common.h"
 #include <algorithm>
+#include <omp.h>
 #include <chrono>
 #include <ctime>
+#include <cub/cub.cuh>
 #include <numeric>
 #include <sstream>
 #include <string>
 #include <vector>
+#include "xgboost/logging.h"
+#include "xgboost/host_device_vector.h"
+#include "xgboost/span.h"
+#include "common.h"
 #include "timer.h"

 #ifdef XGBOOST_USE_NCCL
@@ -205,24 +206,53 @@ __global__ void LaunchNKernel(size_t begin, size_t end, L lambda) {
   }
 }
 template <typename L>
 __global__ void LaunchNKernel(int device_idx, size_t begin, size_t end,
                               L lambda) {
   for (auto i : GridStrideRange(begin, end)) {
     lambda(i, device_idx);
   }
 }

+/* \brief A wrapper around kernel launching syntax, used to guard against empty input.
+ *
+ * - nvcc fails to deduce template argument when kernel is a template accepting __device__
+ *   function as argument.  Hence functions like `LaunchN` cannot use this wrapper.
+ *
+ * - With c++ initialization list `{}` syntax, you are forced to comply with the CUDA type
+ *   spcification.
+ */
+class LaunchKernel {
+  size_t shmem_size_;
+  cudaStream_t stream_;
+
+  dim3 grids_;
+  dim3 blocks_;
+
+ public:
+  LaunchKernel(uint32_t _grids, uint32_t _blk, size_t _shmem=0, cudaStream_t _s=0) :
+      grids_{_grids, 1, 1}, blocks_{_blk, 1, 1}, shmem_size_{_shmem}, stream_{_s} {}
+  LaunchKernel(dim3 _grids, dim3 _blk, size_t _shmem=0, cudaStream_t _s=0) :
+      grids_{_grids}, blocks_{_blk}, shmem_size_{_shmem}, stream_{_s} {}
+
+  template <typename K, typename... Args>
+  void operator()(K kernel, Args... args) {
+    if (XGBOOST_EXPECT(grids_.x * grids_.y * grids_.z == 0, false)) {
+      LOG(DEBUG) << "Skipping empty CUDA kernel.";
+      return;
+    }
+    kernel<<<grids_, blocks_, shmem_size_, stream_>>>(args...);  // NOLINT
+  }
+};
+
 template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
 inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {
   if (n == 0) {
     return;
   }
   safe_cuda(cudaSetDevice(device_idx));
   const int GRID_SIZE =
       static_cast<int>(xgboost::common::DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
-  LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(static_cast<size_t>(0),
-                                                         n, lambda);
+  LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(  // NOLINT
+      static_cast<size_t>(0), n, lambda);
 }

 // Default stream version
@@ -301,6 +331,16 @@ inline detail::MemoryLogger &GlobalMemoryLogger() {
   return memory_logger;
 }

+// dh::DebugSyncDevice(__FILE__, __LINE__);
+inline void DebugSyncDevice(std::string file="", int32_t line = -1) {
+  if (file != "" && line != -1) {
+    auto rank = rabit::GetRank();
+    LOG(DEBUG) << "R:" << rank << ": " << file << ":" << line;
+  }
+  safe_cuda(cudaDeviceSynchronize());
+  safe_cuda(cudaGetLastError());
+}
+
 namespace detail{
 /**
  * \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
@@ -763,7 +803,7 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
       BLOCK_THREADS, segments, num_segments, count);
   LbsKernel<TILE_SIZE, ITEMS_PER_THREAD, BLOCK_THREADS, OffsetT>
-      <<<uint32_t(num_tiles), BLOCK_THREADS>>>(tmp_tile_coordinates,
+      <<<uint32_t(num_tiles), BLOCK_THREADS>>>(tmp_tile_coordinates,  // NOLINT
           segments + 1, f, num_segments);
 }
@@ -963,7 +1003,6 @@ class SaveCudaContext {
  * streams. Must be initialised before use. If XGBoost is compiled without NCCL
  * this is a dummy class that will error if used with more than one GPU.
  */
-
 class AllReducer {
   bool initialised_;
   size_t allreduce_bytes_;  // Keep statistics of the number of bytes communicated
@@ -986,31 +1025,9 @@ class AllReducer {
    *
    * \param device_ordinal The device ordinal.
    */
-  void Init(int _device_ordinal) {
-#ifdef XGBOOST_USE_NCCL
-    /** \brief this >monitor . init. */
-    device_ordinal = _device_ordinal;
-    id = GetUniqueId();
-    dh::safe_cuda(cudaSetDevice(device_ordinal));
-    dh::safe_nccl(ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rabit::GetRank()));
-    safe_cuda(cudaStreamCreate(&stream));
-    initialised_ = true;
-#endif
-  }
-
-  ~AllReducer() {
-#ifdef XGBOOST_USE_NCCL
-    if (initialised_) {
-      dh::safe_cuda(cudaStreamDestroy(stream));
-      ncclCommDestroy(comm);
-    }
-    if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
-      LOG(CONSOLE) << "======== NCCL Statistics========";
-      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
-      LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576;
-    }
-#endif
-  }
+  void Init(int _device_ordinal);
+  ~AllReducer();

   /**
    * \brief Allreduce. Use in exactly the same way as NCCL but without needing


@@ -293,6 +293,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
 void DenseCuts::Init
 (std::vector<WXQSketch>* in_sketchs, uint32_t max_num_bins) {
+  monitor_.Start(__func__);
   std::vector<WXQSketch>& sketchs = *in_sketchs;
   constexpr int kFactor = 8;
   // gather the histogram data
@@ -332,6 +333,7 @@ void DenseCuts::Init
     CHECK_GT(cut_size, p_cuts_->cut_ptrs_.back());
     p_cuts_->cut_ptrs_.push_back(cut_size);
   }
+  monitor_.Stop(__func__);
 }

 void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {


@@ -252,8 +252,10 @@ class GPUSketcher {
         });
     } else if (n_cuts_cur_[icol] > 0) {
       // if more elements than cuts: use binary search on cumulative weights
-      int block = 256;
-      FindCutsK<<<common::DivRoundUp(n_cuts_cur_[icol], block), block>>>(
+      uint32_t constexpr kBlockThreads = 256;
+      uint32_t const kGrids = common::DivRoundUp(n_cuts_cur_[icol], kBlockThreads);
+      dh::LaunchKernel {kGrids, kBlockThreads} (
+          FindCutsK,
           cuts_d_.data().get() + icol * n_cuts_,
           fvalues_cur_.data().get(),
           weights2_.data().get(),
@@ -403,7 +405,8 @@ class GPUSketcher {
     // NOTE: This will typically support ~ 4M features - 64K*64
     dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
                common::DivRoundUp(num_cols_, block3.y), 1);
-    UnpackFeaturesK<<<grid3, block3>>>(
+    dh::LaunchKernel {grid3, block3} (
+        UnpackFeaturesK,
         fvalues_.data().get(),
         has_weights_ ? feature_weights_.data().get() : nullptr,
         row_ptrs_.data().get() + batch_row_begin,


@@ -13,6 +13,20 @@
 namespace xgboost {
 namespace common {

+void Monitor::Start(std::string const &name) {
+  if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+    statistics_map[name].timer.Start();
+  }
+}
+
+void Monitor::Stop(const std::string &name) {
+  if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+    auto &stats = statistics_map[name];
+    stats.timer.Stop();
+    stats.count++;
+  }
+}
+
 std::vector<Monitor::StatMap> Monitor::CollectFromOtherRanks() const {
   // Since other nodes might have started timers that this one haven't, so
   // we can't simply call all reduce.

src/common/timer.cu (new file)

@@ -0,0 +1,38 @@
+/*!
+ * Copyright by Contributors 2019
+ */
+#if defined(XGBOOST_USE_NVTX)
+#include <nvToolsExt.h>
+#endif  // defined(XGBOOST_USE_NVTX)
+
+#include <string>
+
+#include "xgboost/logging.h"
+#include "device_helpers.cuh"
+#include "timer.h"
+
+namespace xgboost {
+namespace common {
+
+void Monitor::StartCuda(const std::string& name) {
+  if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+    auto &stats = statistics_map[name];
+    stats.timer.Start();
+#if defined(XGBOOST_USE_NVTX)
+    stats.nvtx_id = nvtxRangeStartA(name.c_str());
+#endif  // defined(XGBOOST_USE_NVTX)
+  }
+}
+
+void Monitor::StopCuda(const std::string& name) {
+  if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+    auto &stats = statistics_map[name];
+    stats.timer.Stop();
+    stats.count++;
+#if defined(XGBOOST_USE_NVTX)
+    nvtxRangeEnd(stats.nvtx_id);
+#endif  // defined(XGBOOST_USE_NVTX)
+  }
+}
+}  // namespace common
+}  // namespace xgboost


@@ -10,10 +10,6 @@
 #include <utility>
 #include <vector>

-#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-#include <nvToolsExt.h>
-#endif  // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-
 namespace xgboost {
 namespace common {
@@ -84,37 +80,10 @@ struct Monitor {
   void Print() const;

   void Init(std::string label) { this->label = label; }
-  void Start(const std::string &name) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-      statistics_map[name].timer.Start();
-    }
-  }
-  void Stop(const std::string &name) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-      auto &stats = statistics_map[name];
-      stats.timer.Stop();
-      stats.count++;
-    }
-  }
-  void StartCuda(const std::string &name) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-      auto &stats = statistics_map[name];
-      stats.timer.Start();
-#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-      stats.nvtx_id = nvtxRangeStartA(name.c_str());
-#endif  // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-    }
-  }
-  void StopCuda(const std::string &name) {
-    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-      auto &stats = statistics_map[name];
-      stats.timer.Stop();
-      stats.count++;
-#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-      nvtxRangeEnd(stats.nvtx_id);
-#endif  // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
-    }
-  }
+  void Start(const std::string &name);
+  void Stop(const std::string &name);
+  void StartCuda(const std::string &name);
+  void StopCuda(const std::string &name);
 };
 }  // namespace common
 }  // namespace xgboost


@@ -133,9 +133,12 @@ class Transform {
       size_t shard_size = range_size;
       Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
       dh::safe_cuda(cudaSetDevice(device_));
-      const int GRID_SIZE =
+      const int kGrids =
           static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads));
-      detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
+      if (kGrids == 0) {
+        return;
+      }
+      detail::LaunchCUDAKernel<<<kGrids, kBlockThreads>>>(  // NOLINT
           _func, shard_range, UnpackHDVOnDevice(_vectors)...);
     }
 #else


@@ -320,6 +320,32 @@ void DMatrix::SaveToLocalFile(const std::string& fname) {
 DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
                          const std::string& cache_prefix) {
   if (cache_prefix.length() == 0) {
+    // FIXME(trivialfis): Currently distcol is broken so we here check for number of rows.
+    // If we bring back column split this check will break.
+    bool is_distributed { rabit::IsDistributed() };
+    if (is_distributed) {
+      auto world_size = rabit::GetWorldSize();
+      auto rank = rabit::GetRank();
+      std::vector<uint64_t> ncols(world_size, 0);
+      ncols[rank] = source->info.num_col_;
+      rabit::Allreduce<rabit::op::Sum>(ncols.data(), ncols.size());
+      auto max_cols = std::max_element(ncols.cbegin(), ncols.cend());
+      auto max_ind = std::distance(ncols.cbegin(), max_cols);
+
+      // FIXME(trivialfis): This is a hack, we should store a reference to global shape if possible.
+      if (source->info.num_col_ == 0 && source->info.num_row_ == 0) {
+        LOG(WARNING) << "DMatrix at rank: " << rank << " worker is empty.";
+        source->info.num_col_ = *max_cols;
+      }
+
+      // validate the number of columns across all workers.
+      for (size_t i = 0; i < ncols.size(); ++i) {
+        auto v = ncols[i];
+        CHECK(v == 0 || v == *max_cols)
+            << "DMatrix at rank: " << i << " worker "
+            << "has different number of columns than rank: " << max_ind << " worker. "
+            << "(" << v << " vs. " << *max_cols << ")";
+      }
+    }
     return new data::SimpleDMatrix(std::move(source));
   } else {
 #if DMLC_ENABLE_STD_THREAD


@@ -99,13 +99,13 @@ EllpackInfo::EllpackInfo(int device,
                          bool is_dense,
                          size_t row_stride,
                          const common::HistogramCuts& hmat,
-                         dh::BulkAllocator& ba)
+                         dh::BulkAllocator* ba)
     : is_dense(is_dense), row_stride(row_stride), n_bins(hmat.Ptrs().back()) {
-  ba.Allocate(device,
-              &feature_segments, hmat.Ptrs().size(),
-              &gidx_fvalue_map, hmat.Values().size(),
-              &min_fvalue, hmat.MinValues().size());
+  ba->Allocate(device,
+               &feature_segments, hmat.Ptrs().size(),
+               &gidx_fvalue_map, hmat.Values().size(),
+               &min_fvalue, hmat.MinValues().size());
   dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values());
   dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues());
   dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs());
@@ -116,7 +116,7 @@ void EllpackPageImpl::InitInfo(int device,
                                bool is_dense,
                                size_t row_stride,
                                const common::HistogramCuts& hmat) {
-  matrix.info = EllpackInfo(device, is_dense, row_stride, hmat, ba_);
+  matrix.info = EllpackInfo(device, is_dense, row_stride, hmat, &ba_);
 }

 // Initialize the buffer to stored compressed features.
@@ -189,7 +189,8 @@ void EllpackPageImpl::CreateHistIndices(int device,
     const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
                      common::DivRoundUp(row_stride, block3.y),
                      1);
-    CompressBinEllpackKernel<<<grid3, block3>>>(
+    dh::LaunchKernel {grid3, block3} (
+        CompressBinEllpackKernel,
         common::CompressedBufferWriter(num_symbols),
         gidx_buffer.data(),
         row_ptrs.data().get(),


@@ -70,7 +70,7 @@ struct EllpackInfo {
               bool is_dense,
               size_t row_stride,
               const common::HistogramCuts& hmat,
-              dh::BulkAllocator& ba);
+              dh::BulkAllocator* ba);
 };

 /** \brief Struct for accessing and manipulating an ellpack matrix on the


@@ -85,7 +85,7 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
   monitor_.StopCuda("Quantiles");

   monitor_.StartCuda("CreateEllpackInfo");
-  ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, ba_);
+  ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, &ba_);
   monitor_.StopCuda("CreateEllpackInfo");

   monitor_.StartCuda("WriteEllpackPages");


@@ -101,7 +101,7 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
                 HostDeviceVector<size_t>* out_offset,
                 dh::caching_device_vector<int32_t>* out_d_flag,
                 uint32_t* out_n_rows) {
-  int32_t constexpr kThreads = 256;
+  uint32_t constexpr kThreads = 256;
   auto const& j_column = j_columns[column_id];
   auto const& column_obj = get<Object const>(j_column);
   Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj);
@@ -123,8 +123,9 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
   common::Span<size_t> s_offsets = out_offset->DeviceSpan();

-  int32_t const kBlocks = common::DivRoundUp(n_rows, kThreads);
-  CountValidKernel<T><<<kBlocks, kThreads>>>(
+  uint32_t const kBlocks = common::DivRoundUp(n_rows, kThreads);
+  dh::LaunchKernel {kBlocks, kThreads} (
+      CountValidKernel<T>,
       foreign_column,
       has_missing, missing,
       out_d_flag->data().get(), s_offsets);
@@ -135,13 +136,15 @@ template <typename T>
 void CreateCSR(std::vector<Json> const& j_columns, uint32_t column_id, uint32_t n_rows,
                bool has_missing, float missing,
                dh::device_vector<size_t>* tmp_offset, common::Span<Entry> s_data) {
-  int32_t constexpr kThreads = 256;
+  uint32_t constexpr kThreads = 256;
   auto const& j_column = j_columns[column_id];
   auto const& column_obj = get<Object const>(j_column);
   Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj);

-  int32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
-  CreateCSRKernel<T><<<kBlocks, kThreads>>>(foreign_column, column_id, has_missing, missing,
-                                            dh::ToSpan(*tmp_offset), s_data);
+  uint32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
+  dh::LaunchKernel {kBlocks, kThreads} (
+      CreateCSRKernel<T>,
+      foreign_column, column_id, has_missing, missing,
+      dh::ToSpan(*tmp_offset), s_data);
 }

 void SimpleCSRSource::FromDeviceColumnar(std::vector<Json> const& columns,


@@ -246,6 +246,14 @@ class GBTree : public GradientBooster {
   std::unique_ptr<Predictor> const& GetPredictor(HostDeviceVector<float> const* out_pred = nullptr,
                                                  DMatrix* f_dmat = nullptr) const {
     CHECK(configured_);
+    auto on_device = f_dmat && (*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
+#if defined(XGBOOST_USE_CUDA)
+    // Use GPU Predictor if data is already on device.
+    if (!specified_predictor_ && on_device) {
+      CHECK(gpu_predictor_);
+      return gpu_predictor_;
+    }
+#endif  // defined(XGBOOST_USE_CUDA)
     // GPU_Hist by default has prediction cache calculated from quantile values, so GPU
     // Predictor is not used for training dataset. But when XGBoost performs continue
     // training with an existing model, the prediction cache is not availbale and number
@@ -256,7 +264,7 @@ class GBTree : public GradientBooster {
         (model_.param.num_trees != 0) &&
         // FIXME(trivialfis): Implement a better method for testing whether data is on
         // device after DMatrix refactoring is done.
-        (f_dmat && !((*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead()))) {
+        !on_device) {
       return cpu_predictor_;
     }
     if (tparam_.predictor == "cpu_predictor") {


@@ -630,7 +630,7 @@ class LearnerImpl : public Learner {
       CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
           << "Unfortunately, XGBoost does not support data matrices with "
           << std::numeric_limits<unsigned>::max() << " features or greater";
-      num_feature = std::max(num_feature, static_cast<unsigned>(num_col));
+      num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
     }
     // run allreduce on num_feature to find the maximum value
     rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr, "num_feature");


@@ -3,6 +3,8 @@
  * \file elementwise_metric.cc
  * \brief evaluation metrics for elementwise binary or regression.
  * \author Kailong Chen, Tianqi Chen
+ *
+ *  The expressions like wsum == 0 ? esum : esum / wsum is used to handle empty dataset.
  */
 #include <rabit/rabit.h>
 #include <xgboost/metric.h>
@@ -142,7 +144,7 @@ struct EvalRowRMSE {
     return diff * diff;
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return std::sqrt(esum / wsum);
+    return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum);
   }
 };
@@ -150,12 +152,13 @@ struct EvalRowRMSLE {
   char const* Name() const {
     return "rmsle";
   }
+
   XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
     bst_float diff = std::log1p(label) - std::log1p(pred);
     return diff * diff;
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return std::sqrt(esum / wsum);
+    return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum);
   }
 };
@@ -168,7 +171,7 @@ struct EvalRowMAE {
     return std::abs(label - pred);
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }
 };
@@ -190,7 +193,7 @@ struct EvalRowLogLoss {
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }
 };
@@ -225,7 +228,7 @@ struct EvalError {
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }

  private:
@@ -245,7 +248,7 @@ struct EvalPoissonNegLogLik {
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }
 };
@@ -278,7 +281,7 @@ struct EvalGammaNLogLik {
     return -((y * theta - b) / a + c);
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }
 };
@@ -304,7 +307,7 @@ struct EvalTweedieNLogLik {
     return -a + b;
   }
   static bst_float GetFinal(bst_float esum, bst_float wsum) {
-    return esum / wsum;
+    return wsum == 0 ? esum : esum / wsum;
   }

  protected:
@@ -323,7 +326,9 @@ struct EvalEWiseBase : public Metric {
   bst_float Eval(const HostDeviceVector<bst_float>& preds,
                  const MetaInfo& info,
                  bool distributed) override {
-    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
+    if (info.labels_.Size() == 0) {
+      LOG(WARNING) << "label set is empty";
+    }
     CHECK_EQ(preds.Size(), info.labels_.Size())
         << "label and prediction size not match, "
         << "hint: use merror or mlogloss for multi-class classification";
@@ -333,6 +338,7 @@ struct EvalEWiseBase : public Metric {
     reducer_.Reduce(*tparam_, device, info.weights_, info.labels_, preds);

     double dat[2] { result.Residue(), result.Weights() };
+
     if (distributed) {
       rabit::Allreduce<rabit::op::Sum>(dat, 2);
     }


@@ -54,7 +54,9 @@ class RegLossObj : public ObjFunction {
                    const MetaInfo &info,
                    int iter,
                    HostDeviceVector<GradientPair>* out_gpair) override {
-    CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
+    if (info.labels_.Size() == 0U) {
+      LOG(WARNING) << "Label set is empty.";
+    }
     CHECK_EQ(preds.Size(), info.labels_.Size())
         << "labels are not correctly provided"
         << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();


@@ -60,6 +60,9 @@ class CPUPredictor : public Predictor {
     constexpr int kUnroll = 8;
     const auto nsize = static_cast<bst_omp_uint>(batch.Size());
     const bst_omp_uint rest = nsize % kUnroll;
+    // Pull to host before entering omp block, as this is not thread safe.
+    batch.data.HostVector();
+    batch.offset.HostVector();
 #pragma omp parallel for schedule(static)
     for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
       const int tid = omp_get_thread_num();


@@ -225,12 +225,12 @@ class GPUPredictor : public xgboost::Predictor {
                     HostDeviceVector<bst_float>* predictions,
                     size_t batch_offset) {
     dh::safe_cuda(cudaSetDevice(device_));
-    const int BLOCK_THREADS = 128;
+    const uint32_t BLOCK_THREADS = 128;
     size_t num_rows = batch.Size();
-    const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
+    auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));

-    int shared_memory_bytes = static_cast<int>
-        (sizeof(float) * num_features * BLOCK_THREADS);
+    auto shared_memory_bytes =
+        static_cast<size_t>(sizeof(float) * num_features * BLOCK_THREADS);
     bool use_shared = true;
     if (shared_memory_bytes > max_shared_memory_bytes_) {
       shared_memory_bytes = 0;
@@ -238,11 +238,12 @@ class GPUPredictor : public xgboost::Predictor {
     }
     size_t entry_start = 0;

-    PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
-        (dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
-         dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
-         batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
-         entry_start, use_shared, this->num_group_);
+    dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
+        PredictKernel<BLOCK_THREADS>,
+        dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
+        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
+        batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
+        entry_start, use_shared, this->num_group_);
   }

   void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {


@@ -165,10 +165,11 @@ __global__ void ClearBuffersKernel(
 void FeatureInteractionConstraint::ClearBuffers() {
   CHECK_EQ(output_buffer_bits_.Size(), input_buffer_bits_.Size());
   CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size());
-  int constexpr kBlockThreads = 256;
-  const int n_grids = static_cast<int>(
+  uint32_t constexpr kBlockThreads = 256;
+  auto const n_grids = static_cast<uint32_t>(
       common::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads));
-  ClearBuffersKernel<<<n_grids, kBlockThreads>>>(
+  dh::LaunchKernel {n_grids, kBlockThreads} (
+      ClearBuffersKernel,
       output_buffer_bits_, input_buffer_bits_);
 }
@@ -222,12 +223,14 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
   LBitField64 node_constraints = s_node_constraints_[nid];
   CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size());
-  int constexpr kBlockThreads = 256;
-  const int n_grids = static_cast<int>(
+  uint32_t constexpr kBlockThreads = 256;
+  auto n_grids = static_cast<uint32_t>(
       common::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads));
-  SetInputBufferKernel<<<n_grids, kBlockThreads>>>(feature_list, input_buffer_bits_);
-  QueryFeatureListKernel<<<n_grids, kBlockThreads>>>(
+  dh::LaunchKernel {n_grids, kBlockThreads} (
+      SetInputBufferKernel,
+      feature_list, input_buffer_bits_);
+  dh::LaunchKernel {n_grids, kBlockThreads} (
+      QueryFeatureListKernel,
       node_constraints, input_buffer_bits_, output_buffer_bits_);
   thrust::counting_iterator<int32_t> begin(0);
@@ -327,20 +330,20 @@ void FeatureInteractionConstraint::Split(
   dim3 const block3(16, 64, 1);
   dim3 const grid3(common::DivRoundUp(n_sets_, 16),
                    common::DivRoundUp(s_fconstraints_.size(), 64));
-  RestoreFeatureListFromSetsKernel<<<grid3, block3>>>
-      (feature_buffer_,
-       feature_id,
-       s_fconstraints_,
-       s_fconstraints_ptr_,
-       s_sets_,
-       s_sets_ptr_);
+  dh::LaunchKernel {grid3, block3} (
+      RestoreFeatureListFromSetsKernel,
+      feature_buffer_, feature_id,
+      s_fconstraints_, s_fconstraints_ptr_,
+      s_sets_, s_sets_ptr_);

-  int constexpr kBlockThreads = 256;
-  const int n_grids = static_cast<int>(common::DivRoundUp(node.Size(), kBlockThreads));
-  InteractionConstraintSplitKernel<<<n_grids, kBlockThreads>>>
-      (feature_buffer_,
-       feature_id,
-       node, left, right);
+  uint32_t constexpr kBlockThreads = 256;
+  auto n_grids = static_cast<uint32_t>(common::DivRoundUp(node.Size(), kBlockThreads));
+  dh::LaunchKernel {n_grids, kBlockThreads} (
+      InteractionConstraintSplitKernel,
+      feature_buffer_,
+      feature_id,
+      node, left, right);
 }
 }  // namespace xgboost
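The Split hunk above also launches with dim3 configurations ({grid3, block3}), so the wrapper must accept those too. A hedged sketch of such an overload, with the same caveat that the real dh::LaunchKernel signature may differ:

    // Sketch of a dim3 overload: a zero extent in any dimension skips the launch.
    template <typename Kernel, typename... Args>
    void LaunchKernelSketch(dim3 grid, dim3 block, Kernel kernel, Args... args) {
      if (grid.x == 0 || grid.y == 0 || grid.z == 0 ||
          block.x == 0 || block.y == 0 || block.z == 0) {
        return;
      }
      kernel<<<grid, block>>>(args...);  // NOLINT
    }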
@@ -603,12 +603,12 @@ struct GPUHistMakerDevice {
   }
   // One block for each feature
-  int constexpr kBlockThreads = 256;
-  EvaluateSplitKernel<kBlockThreads, GradientSumT>
-      <<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
-          hist.GetNodeHistogram(nidx), d_feature_set, node, page->matrix,
-          gpu_param, d_split_candidates, node_value_constraints[nidx],
-          monotone_constraints);
+  uint32_t constexpr kBlockThreads = 256;
+  dh::LaunchKernel {uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]} (
+      EvaluateSplitKernel<kBlockThreads, GradientSumT>,
+      hist.GetNodeHistogram(nidx), d_feature_set, node, page->matrix,
+      gpu_param, d_split_candidates, node_value_constraints[nidx],
+      monotone_constraints);
   // Reduce over features to find best feature
   auto d_cub_memory =
@@ -638,14 +638,12 @@ struct GPUHistMakerDevice {
       use_shared_memory_histograms
           ? sizeof(GradientSumT) * page->matrix.BinCount()
           : 0;
-  const int items_per_thread = 8;
-  const int block_threads = 256;
-  const int grid_size = static_cast<int>(
+  uint32_t items_per_thread = 8;
+  uint32_t block_threads = 256;
+  auto grid_size = static_cast<uint32_t>(
       common::DivRoundUp(n_elements, items_per_thread * block_threads));
-  if (grid_size <= 0) {
-    return;
-  }
-  SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
+  dh::LaunchKernel {grid_size, block_threads, smem_size} (
+      SharedMemHistKernel<GradientSumT>,
       page->matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
       use_shared_memory_histograms);
 }
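Note that the explicit `if (grid_size <= 0) { return; }` guard is deleted here rather than moved: when a worker holds no rows, `n_elements` is zero and the rounded-up division already yields a zero grid, which the wrapper skips. A small model of that arithmetic, assuming `common::DivRoundUp` has the usual `(a + b - 1) / b` semantics:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Assumed semantics of common::DivRoundUp, for illustration only.
    std::uint32_t DivRoundUpSketch(std::size_t a, std::size_t b) {
      return static_cast<std::uint32_t>((a + b - 1) / b);
    }

    int main() {
      assert(DivRoundUpSketch(0, 8 * 256) == 0);  // no elements -> zero grid
      assert(DivRoundUpSketch(1, 8 * 256) == 1);  // any element -> one block
      return 0;
    }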
@@ -886,6 +884,7 @@ struct GPUHistMakerDevice {
     monitor.StartCuda("InitRoot");
     this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_);
     monitor.StopCuda("InitRoot");

     auto timestamp = qexpand->size();
     auto num_leaves = 1;
@@ -895,7 +894,6 @@ struct GPUHistMakerDevice {
       if (!candidate.IsValid(param, num_leaves)) {
         continue;
       }
       this->ApplySplit(candidate, p_tree);
       num_leaves++;
@@ -996,18 +994,22 @@ class GPUHistMakerSpecialised {
     try {
       for (xgboost::RegTree* tree : trees) {
         this->UpdateTree(gpair, dmat, tree);
+        if (hist_maker_param_.debug_synchronize) {
+          this->CheckTreesSynchronized(tree);
+        }
       }
       dh::safe_cuda(cudaGetLastError());
     } catch (const std::exception& e) {
       LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
     }
     param_.learning_rate = lr;
     monitor_.StopCuda("Update");
   }

   void InitDataOnce(DMatrix* dmat) {
     info_ = &dmat->Info();
     reducer_.Init({device_});
     // Synchronise the column sampling seed
@@ -1048,20 +1050,18 @@ class GPUHistMakerSpecialised {
   }
   // Only call this method for testing
-  void CheckTreesSynchronized(const std::vector<RegTree>& local_trees) const {
+  void CheckTreesSynchronized(RegTree* local_tree) const {
     std::string s_model;
     common::MemoryBufferStream fs(&s_model);
     int rank = rabit::GetRank();
     if (rank == 0) {
-      local_trees.front().SaveModel(&fs);
+      local_tree->SaveModel(&fs);
     }
     fs.Seek(0);
     rabit::Broadcast(&s_model, 0);
-    RegTree reference_tree{};
+    RegTree reference_tree {};  // rank 0 tree
     reference_tree.LoadModel(&fs);
-    for (const auto& tree : local_trees) {
-      CHECK(tree == reference_tree);
-    }
+    CHECK(*local_tree == reference_tree);
   }

   void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
@@ -18,7 +18,7 @@ ENV PATH=/opt/python/bin:$PATH
 # Create new Conda environment with cuDF and dask
 RUN \
     conda create -n cudf_test -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda \
-        cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask
+        cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask dask-cuda

 # Install other Python packages
 RUN \
@@ -17,7 +17,8 @@ ENV PATH=/opt/python/bin:$PATH
 # Install Python packages
 RUN \
    pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \
-   pip install "dask[complete]"
+   pip install "dask[complete]" && \
+   conda install -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda dask-cuda

 ENV GOSU_VERSION 1.10
@@ -21,18 +21,12 @@ RUN \
 # NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html)
 RUN \
    export CUDA_SHORT=`echo $CUDA_VERSION | egrep -o '[0-9]+\.[0-9]'` && \
-   if [ "${CUDA_SHORT}" != "10.0" ] && [ "${CUDA_SHORT}" != "10.1" ]; then \
-   wget https://developer.download.nvidia.com/compute/redist/nccl/v2.2/nccl_2.2.13-1%2Bcuda${CUDA_SHORT}_x86_64.txz && \
-   tar xf "nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64.txz" && \
-   cp nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64/include/nccl.h /usr/include && \
-   cp nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64/lib/* /usr/lib && \
-   rm -f nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64.txz && \
-   rm -r nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64; else \
+   export NCCL_VERSION=2.4.8-1 && \
    wget https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
    rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
    yum -y update && \
-   yum install -y libnccl-2.4.2-1+cuda${CUDA_SHORT} libnccl-devel-2.4.2-1+cuda${CUDA_SHORT} libnccl-static-2.4.2-1+cuda${CUDA_SHORT} && \
-   rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm; fi
+   yum install -y libnccl-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-devel-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-static-${NCCL_VERSION}+cuda${CUDA_SHORT} && \
+   rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm;

 ENV PATH=/opt/python/bin:$PATH
 ENV CC=/opt/rh/devtoolset-4/root/usr/bin/gcc
@@ -33,11 +33,13 @@ case "$suite" in
        pytest -v -s --fulltrace -m "(not slow) and mgpu" tests/python-gpu
        cd tests/distributed
        ./runtests-gpu.sh
+       cd -
+       pytest -v -s --fulltrace -m "mgpu" tests/python-gpu/test_gpu_with_dask.py
        ;;
    cudf)
        source activate cudf_test
-       python -m pytest -v -s --fulltrace tests/python-gpu/test_from_columnar.py tests/python-gpu/test_gpu_with_dask.py
+       pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py
        ;;
    cpu)
@@ -19,7 +19,7 @@ if (USE_CUDA)
   # OpenMP is mandatory for CUDA
   find_package(OpenMP REQUIRED)
   target_include_directories(testxgboost PRIVATE
-    ${PROJECT_SOURCE_DIR}/cub/)
+    ${xgboost_SOURCE_DIR}/cub/)
   target_compile_options(testxgboost PRIVATE
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
@@ -48,9 +48,9 @@ endif (USE_CUDA)
 target_include_directories(testxgboost
   PRIVATE
   ${GTEST_INCLUDE_DIRS}
-  ${PROJECT_SOURCE_DIR}/include
-  ${PROJECT_SOURCE_DIR}/dmlc-core/include
-  ${PROJECT_SOURCE_DIR}/rabit/include)
+  ${xgboost_SOURCE_DIR}/include
+  ${xgboost_SOURCE_DIR}/dmlc-core/include
+  ${xgboost_SOURCE_DIR}/rabit/include)
 set_target_properties(
   testxgboost PROPERTIES
   CXX_STANDARD 11
@@ -67,7 +67,7 @@ target_compile_definitions(testxgboost PRIVATE ${XGBOOST_DEFINITIONS})
 if (USE_OPENMP)
   target_compile_options(testxgboost PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${OpenMP_CXX_FLAGS}>)
 endif (USE_OPENMP)
-set_output_directory(testxgboost ${PROJECT_BINARY_DIR})
+set_output_directory(testxgboost ${xgboost_BINARY_DIR})

 # This grouping organises source files nicely in visual studio
 auto_source_group("${TEST_SOURCES}")
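The `PROJECT_SOURCE_DIR` to `xgboost_SOURCE_DIR` renames are not cosmetic: `PROJECT_SOURCE_DIR` tracks the most recent enclosing `project()` call, so its value changes when xgboost is consumed through `add_subdirectory()` from a parent project, whereas `xgboost_SOURCE_DIR` always refers to this repository's root. The same reasoning applies to `xgboost_BINARY_DIR`.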
@@ -2,6 +2,7 @@
 import sys
 import time
 import xgboost as xgb
+import os


 def run_test(name, params_fun):
@@ -48,6 +49,9 @@ def run_test(name, params_fun):
     xgb.rabit.finalize()
+    if os.path.exists(model_name):
+        os.remove(model_name)

 base_params = {
     'tree_method': 'gpu_hist',
@@ -81,7 +85,5 @@ def wrap_rf(params_fun):
 params_rf_1x4 = wrap_rf(params_basic_1x4)

 test_name = sys.argv[1]
 run_test(test_name, globals()['params_%s' % test_name])
@@ -6,7 +6,7 @@ export DMLC_SUBMIT_CLUSTER=local
 submit="timeout 30 python ../../dmlc-core/tracker/dmlc-submit"

 echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
-$submit --num-workers=4 python distributed_gpu.py basic_1x4 || exit 1
+$submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py basic_1x4 || exit 1

 echo -e "\n ====== 2. RF distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
-$submit --num-workers=4 python distributed_gpu.py rf_1x4 || exit 1
+$submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py rf_1x4 || exit 1
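`nvidia-smi -L` prints one line per visible GPU, so `$(nvidia-smi -L | wc -l)` launches exactly one worker per available GPU instead of hard-coding four, letting the same script run on machines with any GPU count.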
tests/pytest.ini (new file):
@@ -0,0 +1,3 @@
+[pytest]
+markers =
+    mgpu: Mark a test that requires multiple GPUs to run.
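Registering the marker keeps pytest from warning about an unknown `mgpu` mark, and gives the CI script above a stable handle for splitting suites: multi-GPU tests are selected with `-m "mgpu"` and excluded with `-m "not mgpu"`.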
@@ -2,6 +2,7 @@ import numpy as np
 import sys
 import unittest
 import pytest
+import xgboost

 sys.path.append("tests/python")
 from regression_test_utilities import run_suite, parameter_combinations, \
@@ -21,7 +22,8 @@ datasets = ["Boston", "Cancer", "Digits", "Sparse regression",
 class TestGPU(unittest.TestCase):
     def test_gpu_hist(self):
-        test_param = parameter_combinations({'gpu_id': [0], 'max_depth': [2, 8],
+        test_param = parameter_combinations({'gpu_id': [0],
+                                             'max_depth': [2, 8],
                                              'max_leaves': [255, 4],
                                              'max_bin': [2, 256],
                                              'grow_policy': ['lossguide']})
@@ -36,6 +38,31 @@ class TestGPU(unittest.TestCase):
             cpu_results = run_suite(param, select_datasets=datasets)
             assert_gpu_results(cpu_results, gpu_results)

+    def test_with_empty_dmatrix(self):
+        # FIXME(trivialfis): This should be done with all updaters
+        kRows = 0
+        kCols = 100
+
+        X = np.empty((kRows, kCols))
+        y = np.empty((kRows))
+
+        dtrain = xgboost.DMatrix(X, y)
+
+        bst = xgboost.train({'verbosity': 2,
+                             'tree_method': 'gpu_hist',
+                             'gpu_id': 0},
+                            dtrain,
+                            verbose_eval=True,
+                            num_boost_round=6,
+                            evals=[(dtrain, 'Train')])
+
+        kRows = 100
+        X = np.random.randn(kRows, kCols)
+
+        dtest = xgboost.DMatrix(X)
+        predictions = bst.predict(dtest)
+        np.testing.assert_allclose(predictions, 0.5, 1e-6)
+
     @pytest.mark.mgpu
     def test_specified_gpu_id_gpu_update(self):
         variable_param = {'gpu_id': [1],
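The `assert_allclose(predictions, 0.5, 1e-6)` check in the new empty-DMatrix test works because a booster trained on zero rows never makes a split, so every prediction stays at the default `base_score` of 0.5.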
@@ -1,45 +1,94 @@
 import sys
 import pytest
+import numpy as np
+import unittest

 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)

 try:
-    from distributed.utils_test import client, loop, cluster_fixture
     import dask.dataframe as dd
     from xgboost import dask as dxgb
+    from dask_cuda import LocalCUDACluster
+    from dask.distributed import Client
     import cudf
 except ImportError:
-    client = None
-    loop = None
-    cluster_fixture = None
     pass

 sys.path.append("tests/python")
-from test_with_dask import generate_array
-import testing as tm
+from test_with_dask import generate_array  # noqa
+import testing as tm  # noqa


-@pytest.mark.skipif(**tm.no_dask())
-@pytest.mark.skipif(**tm.no_cudf())
-@pytest.mark.skipif(**tm.no_dask_cudf())
-def test_dask_dataframe(client):
-    X, y = generate_array()
-    X = dd.from_dask_array(X)
-    y = dd.from_dask_array(y)
-    X = X.map_partitions(cudf.from_pandas)
-    y = y.map_partitions(cudf.from_pandas)
-    dtrain = dxgb.DaskDMatrix(client, X, y)
-    out = dxgb.train(client, {'tree_method': 'gpu_hist'},
-                     dtrain=dtrain,
-                     evals=[(dtrain, 'X')],
-                     num_boost_round=2)
-    assert isinstance(out['booster'], dxgb.Booster)
-    assert len(out['history']['X']['rmse']) == 2
-    predictions = dxgb.predict(out, dtrain)
-    predictions = predictions.compute()
+class TestDistributedGPU(unittest.TestCase):
+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.skipif(**tm.no_cudf())
+    @pytest.mark.skipif(**tm.no_dask_cudf())
+    @pytest.mark.skipif(**tm.no_dask_cuda())
+    def test_dask_dataframe(self):
+        with LocalCUDACluster() as cluster:
+            with Client(cluster) as client:
+                X, y = generate_array()
+
+                X = dd.from_dask_array(X)
+                y = dd.from_dask_array(y)
+
+                X = X.map_partitions(cudf.from_pandas)
+                y = y.map_partitions(cudf.from_pandas)
+
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+                out = dxgb.train(client, {'tree_method': 'gpu_hist'},
+                                 dtrain=dtrain,
+                                 evals=[(dtrain, 'X')],
+                                 num_boost_round=2)
+
+                assert isinstance(out['booster'], dxgb.Booster)
+                assert len(out['history']['X']['rmse']) == 2
+
+                # FIXME(trivialfis): Re-enable this after #5003 is fixed
+                # predictions = dxgb.predict(client, out, dtrain).compute()
+                # assert isinstance(predictions, np.ndarray)
+
+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.skipif(**tm.no_dask_cuda())
+    @pytest.mark.mgpu
+    def test_empty_dmatrix(self):
+        def _check_outputs(out, predictions):
+            assert isinstance(out['booster'], dxgb.Booster)
+            assert len(out['history']['validation']['rmse']) == 2
+            assert isinstance(predictions, np.ndarray)
+            assert predictions.shape[0] == 1
+
+        parameters = {'tree_method': 'gpu_hist', 'verbosity': 3,
+                      'debug_synchronize': True}
+
+        with LocalCUDACluster() as cluster:
+            with Client(cluster) as client:
+                kRows, kCols = 1, 97
+                X = dd.from_array(np.random.randn(kRows, kCols))
+                y = dd.from_array(np.random.rand(kRows))
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+
+                out = dxgb.train(client, parameters,
+                                 dtrain=dtrain,
+                                 evals=[(dtrain, 'validation')],
+                                 num_boost_round=2)
+                predictions = dxgb.predict(client=client, model=out,
+                                           data=dtrain).compute()
+                _check_outputs(out, predictions)
+
+                # train has more rows than evals
+                valid = dtrain
+                kRows += 1
+                X = dd.from_array(np.random.randn(kRows, kCols))
+                y = dd.from_array(np.random.rand(kRows))
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+
+                out = dxgb.train(client, parameters,
+                                 dtrain=dtrain,
+                                 evals=[(valid, 'validation')],
+                                 num_boost_round=2)
+                predictions = dxgb.predict(client=client, model=out,
+                                           data=valid).compute()
+                _check_outputs(out, predictions)
@@ -67,7 +67,8 @@ def get_weights_regression(min_weight, max_weight):
     n = 10000
     sparsity = 0.25
     X, y = datasets.make_regression(n, random_state=rng)
-    X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X])
+    X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x
+                   for x in x_row] for x_row in X])
     w = np.array([rng.uniform(min_weight, max_weight) for i in range(n)])
     return X, y, w
@@ -34,6 +34,15 @@ def no_matplotlib():
             'reason': reason}


+def no_dask_cuda():
+    reason = 'dask_cuda is not installed.'
+    try:
+        import dask_cuda as _  # noqa
+        return {'condition': False, 'reason': reason}
+    except ImportError:
+        return {'condition': True, 'reason': reason}
+
+
 def no_cudf():
     return {'condition': not CUDF_INSTALLED,
             'reason': 'CUDF is not installed'}
@@ -34,6 +34,13 @@ fi
 if [ ${TASK} == "cmake_test" ]; then
     set -e

+    if grep -n -R '<<<.*>>>\(.*\)' src include | grep --invert "NOLINT"; then
+        echo 'Do not use raw CUDA execution configuration syntax with <<<blocks, threads>>>.' \
+             'try `dh::LaunchKernel`'
+        exit -1
+    fi
+
     # Build/test
     rm -rf build
     mkdir build && cd build