From 7663de956c37eb4dd528132214e68ba2851d9696 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 6 Nov 2019 16:13:13 +0800 Subject: [PATCH] Run training with empty DMatrix. (#4990) This makes GPU Hist robust in distributed environments, as some workers might not be associated with any data in either training or evaluation. * Disable rabit mock test for now: See #5012. * Disable dask-cudf test at prediction for now: See #5003. * Launch the dask job on all workers, even though some might not have any data. * Check for 0 rows in elementwise evaluation metrics. Using AUC and AUC-PR still throws an error. See #4663 for a robust fix. * Add tests for edge cases. * Add `LaunchKernel` wrapper handling zero-sized grids. * Move some parts of the allreducer into a .cu file. * Don't validate feature names when the booster is empty. * Sync the number of columns in DMatrix, as num_feature is required to be the same across all workers in data split mode. * Filtering in the dask interface now by default returns a booster from any worker with non-empty data, instead of always using rank 0. * Fix Jenkins' GPU tests. * Install dask-cuda from source in Jenkins' test. Now all tests are actually running. * Restore the GPU Hist tree synchronization test. * Check the UUID of running devices. The check is only performed on CUDA versions >= 10.x, as 9.x doesn't have the UUID field. * Fix CMake policy and project variables. Use xgboost_SOURCE_DIR uniformly, add a policy for CMake >= 3.13. * Fix copying data to CPU. * Fix a race condition in the CPU predictor. * Fix duplicated DMatrix construction. * Don't download extra NCCL in the CI script. --- CMakeLists.txt | 37 +++++---- Jenkinsfile | 1 - cmake/Doc.cmake | 2 +- cmake/Utils.cmake | 2 +- cmake/Version.cmake | 3 +- cmake/modules/FindNVML.cmake | 23 ++++++ python-package/xgboost/core.py | 20 ++--- python-package/xgboost/dask.py | 80 +++++++++++-------- python-package/xgboost/tracker.py | 2 +- src/CMakeLists.txt | 10 +-- src/common/device_helpers.cu | 91 +++++++++++++++++++++ src/common/device_helpers.cuh | 95 +++++++++++++--------- src/common/hist_util.cc | 2 + src/common/hist_util.cu | 9 ++- src/common/timer.cc | 14 ++++ src/common/timer.cu | 38 +++++++++ src/common/timer.h | 39 +-------- src/common/transform.h | 7 +- src/data/data.cc | 26 ++++++ src/data/ellpack_page.cu | 15 ++-- src/data/ellpack_page.cuh | 2 +- src/data/ellpack_page_source.cu | 2 +- src/data/simple_csr_source.cu | 17 ++-- src/gbm/gbtree.h | 10 ++- src/learner.cc | 2 +- src/metric/elementwise_metric.cu | 24 +++--- src/objective/regression_obj.cu | 4 +- src/predictor/cpu_predictor.cc | 3 + src/predictor/gpu_predictor.cu | 19 ++--- src/tree/constraints.cu | 45 ++++++----- src/tree/updater_gpu_hist.cu | 42 +++++----- tests/ci_build/Dockerfile.cudf | 2 +- tests/ci_build/Dockerfile.gpu | 3 +- tests/ci_build/Dockerfile.gpu_build | 12 +-- tests/ci_build/test_python.sh | 4 +- tests/cpp/CMakeLists.txt | 10 +-- tests/distributed/distributed_gpu.py | 6 +- tests/distributed/runtests-gpu.sh | 4 +- tests/pytest.ini | 3 + tests/python-gpu/test_gpu_updaters.py | 29 ++++++- tests/python-gpu/test_gpu_with_dask.py | 97 +++++++++++++++++------ tests/python/regression_test_utilities.py | 3 +- tests/python/testing.py | 9 +++ tests/travis/run_test.sh | 7 ++ 44 files changed, 603 insertions(+), 272 deletions(-) create mode 100644 cmake/modules/FindNVML.cmake create mode 100644 src/common/device_helpers.cu create mode 100644 src/common/timer.cu create mode 100644 tests/pytest.ini diff --git a/CMakeLists.txt index f2316882b..206f881e3 100644 --- a/CMakeLists.txt +++ 
b/CMakeLists.txt @@ -1,9 +1,13 @@ cmake_minimum_required(VERSION 3.3) project(xgboost LANGUAGES CXX C VERSION 1.0.0) include(cmake/Utils.cmake) -list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") +list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules") cmake_policy(SET CMP0022 NEW) +if ((${CMAKE_VERSION} VERSION_GREATER 3.13) OR (${CMAKE_VERSION} VERSION_EQUAL 3.13)) + cmake_policy(SET CMP0077 NEW) +endif ((${CMAKE_VERSION} VERSION_GREATER 3.13) OR (${CMAKE_VERSION} VERSION_EQUAL 3.13)) + message(STATUS "CMake version ${CMAKE_VERSION}") if (MSVC) cmake_minimum_required(VERSION 3.11) @@ -84,7 +88,7 @@ endif (USE_CUDA) # dmlc-core msvc_use_static_runtime() -add_subdirectory(${PROJECT_SOURCE_DIR}/dmlc-core) +add_subdirectory(${xgboost_SOURCE_DIR}/dmlc-core) set_target_properties(dmlc PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED ON @@ -105,7 +109,7 @@ endif(RABIT_MOCK) # Exports some R specific definitions and objects if (R_LIB) - add_subdirectory(${PROJECT_SOURCE_DIR}/R-package) + add_subdirectory(${xgboost_SOURCE_DIR}/R-package) endif (R_LIB) # core xgboost @@ -123,22 +127,23 @@ target_link_libraries(xgboost PRIVATE ${LINKED_LIBRARIES_PRIVATE}) # This creates its own shared library `xgboost4j'. if (JVM_BINDINGS) - add_subdirectory(${PROJECT_SOURCE_DIR}/jvm-packages) + add_subdirectory(${xgboost_SOURCE_DIR}/jvm-packages) endif (JVM_BINDINGS) #-- End shared library #-- CLI for xgboost -add_executable(runxgboost ${PROJECT_SOURCE_DIR}/src/cli_main.cc ${XGBOOST_OBJ_SOURCES}) +add_executable(runxgboost ${xgboost_SOURCE_DIR}/src/cli_main.cc ${XGBOOST_OBJ_SOURCES}) # For cli_main.cc only if (USE_OPENMP) find_package(OpenMP REQUIRED) target_compile_options(runxgboost PRIVATE ${OpenMP_CXX_FLAGS}) endif (USE_OPENMP) + target_include_directories(runxgboost PRIVATE - ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/dmlc-core/include - ${PROJECT_SOURCE_DIR}/rabit/include) + ${xgboost_SOURCE_DIR}/include + ${xgboost_SOURCE_DIR}/dmlc-core/include + ${xgboost_SOURCE_DIR}/rabit/include) target_link_libraries(runxgboost PRIVATE ${LINKED_LIBRARIES_PRIVATE}) set_target_properties( runxgboost PROPERTIES @@ -147,8 +152,8 @@ set_target_properties( CXX_STANDARD_REQUIRED ON) #-- End CLI for xgboost -set_output_directory(runxgboost ${PROJECT_SOURCE_DIR}) -set_output_directory(xgboost ${PROJECT_SOURCE_DIR}/lib) +set_output_directory(runxgboost ${xgboost_SOURCE_DIR}) +set_output_directory(xgboost ${xgboost_SOURCE_DIR}/lib) # Ensure these two targets do not build simultaneously, as they produce outputs with conflicting names add_dependencies(xgboost runxgboost) @@ -205,21 +210,21 @@ install( if (GOOGLE_TEST) enable_testing() # Unittests. 
- add_subdirectory(${PROJECT_SOURCE_DIR}/tests/cpp) + add_subdirectory(${xgboost_SOURCE_DIR}/tests/cpp) add_test( NAME TestXGBoostLib COMMAND testxgboost - WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + WORKING_DIRECTORY ${xgboost_BINARY_DIR}) # CLI tests configure_file( - ${PROJECT_SOURCE_DIR}/tests/cli/machine.conf.in - ${PROJECT_BINARY_DIR}/tests/cli/machine.conf + ${xgboost_SOURCE_DIR}/tests/cli/machine.conf.in + ${xgboost_BINARY_DIR}/tests/cli/machine.conf @ONLY) add_test( NAME TestXGBoostCLI - COMMAND runxgboost ${PROJECT_BINARY_DIR}/tests/cli/machine.conf - WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + COMMAND runxgboost ${xgboost_BINARY_DIR}/tests/cli/machine.conf + WORKING_DIRECTORY ${xgboost_BINARY_DIR}) set_tests_properties(TestXGBoostCLI PROPERTIES PASS_REGULAR_EXPRESSION ".*test-rmse:0.087.*") diff --git a/Jenkinsfile b/Jenkinsfile index d6f7d7b80..718a70931 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -83,7 +83,6 @@ pipeline { 'test-python-gpu-cuda10.0': { TestPythonGPU(cuda_version: '10.0') }, 'test-python-gpu-cuda10.1': { TestPythonGPU(cuda_version: '10.1') }, 'test-python-mgpu-cuda10.1': { TestPythonGPU(cuda_version: '10.1', multi_gpu: true) }, - 'test-cpp-rabit': {TestCppRabit()}, 'test-cpp-gpu': { TestCppGPU(cuda_version: '10.1') }, 'test-cpp-mgpu': { TestCppGPU(cuda_version: '10.1', multi_gpu: true) }, 'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '2.4.3') }, diff --git a/cmake/Doc.cmake b/cmake/Doc.cmake index 0122be12c..2ffa005ff 100644 --- a/cmake/Doc.cmake +++ b/cmake/Doc.cmake @@ -6,7 +6,7 @@ function (run_doxygen) endif (NOT DOXYGEN_DOT_FOUND) configure_file( - ${PROJECT_SOURCE_DIR}/doc/Doxyfile.in + ${xgboost_SOURCE_DIR}/doc/Doxyfile.in ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile @ONLY) add_custom_target( doc_doxygen ALL COMMAND ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake index 8922fd0f0..5097a5fd9 100644 --- a/cmake/Utils.cmake +++ b/cmake/Utils.cmake @@ -111,7 +111,7 @@ DESTINATION \"${build_dir}/bak\")") install(CODE "file(REMOVE_RECURSE \"${build_dir}/R-package\")") install( - DIRECTORY "${PROJECT_SOURCE_DIR}/R-package" + DIRECTORY "${xgboost_SOURCE_DIR}/R-package" DESTINATION "${build_dir}" REGEX "src/*" EXCLUDE REGEX "R-package/configure" EXCLUDE diff --git a/cmake/Version.cmake b/cmake/Version.cmake index 0159723e1..a7dfb1c68 100644 --- a/cmake/Version.cmake +++ b/cmake/Version.cmake @@ -5,6 +5,5 @@ function (write_version) ${xgboost_SOURCE_DIR}/include/xgboost/version_config.h @ONLY) configure_file( ${xgboost_SOURCE_DIR}/cmake/Python_version.in - ${xgboost_SOURCE_DIR}/python-package/xgboost/VERSION - ) + ${xgboost_SOURCE_DIR}/python-package/xgboost/VERSION) endfunction (write_version) diff --git a/cmake/modules/FindNVML.cmake b/cmake/modules/FindNVML.cmake new file mode 100644 index 000000000..a4bed0019 --- /dev/null +++ b/cmake/modules/FindNVML.cmake @@ -0,0 +1,23 @@ +if (NVML_LIBRARY) + unset(NVML_LIBRARY CACHE) +endif(NVML_LIBRARY) + +set(NVML_LIB_NAME nvml) + +find_path(NVML_INCLUDE_DIR + NAMES nvml.h + PATHS ${CUDA_HOME}/include ${CUDA_INCLUDE} /usr/local/cuda/include) + +find_library(NVML_LIBRARY + NAMES nvidia-ml) + +message(STATUS "Using nvml library: ${NVML_LIBRARY}") + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NVML DEFAULT_MSG + NVML_INCLUDE_DIR NVML_LIBRARY) + +mark_as_advanced( + NVML_INCLUDE_DIR + NVML_LIBRARY +) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 060117407..8387fcf98 100644 --- 
a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -513,7 +513,7 @@ class DMatrix(object): try: csr = scipy.sparse.csr_matrix(data) self._init_from_csr(csr) - except: + except Exception: raise TypeError('can not initialize DMatrix from' ' {}'.format(type(data).__name__)) @@ -577,9 +577,9 @@ class DMatrix(object): if len(mat.shape) != 2: raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ', mat.shape) - # flatten the array by rows and ensure it is float32. - # we try to avoid data copies if possible (reshape returns a view when possible - # and we explicitly tell np.array to try and avoid copying) + # flatten the array by rows and ensure it is float32. we try to avoid + # data copies if possible (reshape returns a view when possible and we + # explicitly tell np.array to try and avoid copying) data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32) handle = ctypes.c_void_p() missing = missing if missing is not None else np.nan @@ -1391,8 +1391,9 @@ class Booster(object): value of the prediction. Note the last row and column correspond to the bias term. validate_features : bool - When this is True, validate that the Booster's and data's feature_names are identical. - Otherwise, it is assumed that the feature_names are the same. + When this is True, validate that the Booster's and data's + feature_names are identical. Otherwise, it is assumed that the + feature_names are the same. Returns ------- @@ -1811,8 +1812,8 @@ class Booster(object): msg = 'feature_names mismatch: {0} {1}' if dat_missing: - msg += ('\nexpected ' + ', '.join(str(s) for s in dat_missing) + - ' in input data') + msg += ('\nexpected ' + ', '.join( + str(s) for s in dat_missing) + ' in input data') if my_missing: msg += ('\ntraining data did not have the following fields: ' + @@ -1821,7 +1822,8 @@ raise ValueError(msg.format(self.feature_names, data.feature_names)) - def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True): + def get_split_value_histogram(self, feature, fmap='', bins=None, + as_pandas=True): """Get split value histogram of a feature Parameters diff --git a/python-package/xgboost/dask.py index 5a465412f..3181c5ebd 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -55,10 +55,14 @@ def _start_tracker(host, n_workers): return env -def _assert_dask_installed(): +def _assert_dask_support(): if not DASK_INSTALLED: raise ImportError( 'Dask needs to be installed in order to use this module') + if platform.system() == 'Windows': + msg = 'Windows is not officially supported for dask/xgboost,' + msg += ' contributions are welcome.' + logging.warning(msg) class RabitContext: @@ -96,6 +100,11 @@ def _xgb_get_client(client): return ret +def _get_client_workers(client): + workers = client.scheduler_info()['workers'] + return workers + + class DaskDMatrix: # pylint: disable=missing-docstring, too-many-instance-attributes '''DMatrix holding references to Dask DataFrame or Dask Array. @@ -132,7 +141,7 @@ class DaskDMatrix: weight=None, feature_names=None, feature_types=None): - _assert_dask_installed() + _assert_dask_support() self._feature_names = feature_names self._feature_types = feature_types @@ -263,6 +272,17 @@ class DaskDMatrix: A DMatrix object. ''' + if worker.address not in set(self.worker_map.keys()): + msg = 'worker {address} has an empty DMatrix. 
' \ + 'All workers associated with this DMatrix: {workers}'.format( + address=worker.address, + workers=set(self.worker_map.keys())) + logging.warning(msg) + d = DMatrix(numpy.empty((0, 0)), + feature_names=self._feature_names, + feature_types=self._feature_types) + return d + data, labels, weights = self.get_worker_parts(worker) data = concat(data) @@ -275,7 +295,6 @@ class DaskDMatrix: weights = concat(weights) else: weights = None - dmatrix = DMatrix(data, labels, weight=weights, @@ -342,35 +361,33 @@ def train(client, params, dtrain, *args, evals=(), **kwargs): 'eval': {'logloss': ['0.480385', '0.357756']}}} ''' - _assert_dask_installed() - if platform.system() == 'Windows': - msg = 'Windows is not officially supported for dask/xgboost,' - msg += ' contribution are welcomed.' - logging.warning(msg) + _assert_dask_support() if 'evals_result' in kwargs.keys(): raise ValueError( 'evals_result is not supported in dask interface.', 'The evaluation history is returned as result of training.') - client = _xgb_get_client(client) + workers = list(_get_client_workers(client).keys()) - worker_map = dtrain.worker_map - rabit_args = _get_rabit_args(worker_map, client) + rabit_args = _get_rabit_args(workers, client) - def dispatched_train(worker_id): - '''Perform training on worker.''' - logging.info('Training on %d', worker_id) + def dispatched_train(worker_addr): + '''Perform training on a single worker.''' + logging.info('Training on %s', str(worker_addr)) worker = distributed_get_worker() - local_dtrain = dtrain.get_worker_data(worker) - - local_evals = [] - if evals: - for mat, name in evals: - local_mat = mat.get_worker_data(worker) - local_evals.append((local_mat, name)) - with RabitContext(rabit_args): + local_dtrain = dtrain.get_worker_data(worker) + + local_evals = [] + if evals: + for mat, name in evals: + if mat is dtrain: + local_evals.append((local_dtrain, name)) + continue + local_mat = mat.get_worker_data(worker) + local_evals.append((local_mat, name)) + local_history = {} local_param = params.copy() # just to be consistent bst = worker_train(params=local_param, @@ -380,14 +397,14 @@ def train(client, params, dtrain, *args, evals=(), **kwargs): evals=local_evals, **kwargs) ret = {'booster': bst, 'history': local_history} - if rabit.get_rank() != 0: + if local_dtrain.num_row() == 0: ret = None return ret futures = client.map(dispatched_train, - range(len(worker_map)), + workers, pure=False, - workers=list(worker_map.keys())) + workers=workers) results = client.gather(futures) return list(filter(lambda ret: ret is not None, results))[0] @@ -414,7 +431,7 @@ def predict(client, model, data, *args): prediction: dask.array.Array ''' - _assert_dask_installed() + _assert_dask_support() if isinstance(model, Booster): booster = model elif isinstance(model, dict): @@ -437,7 +454,8 @@ def predict(client, model, data, *args): local_x = data.get_worker_data(worker) with RabitContext(rabit_args): - local_predictions = booster.predict(data=local_x, *args) + local_predictions = booster.predict( + data=local_x, validate_features=local_x.num_row() != 0, *args) return local_predictions futures = client.map(dispatched_predict, @@ -563,7 +581,7 @@ class DaskXGBRegressor(DaskScikitLearnBase): sample_weights=None, eval_set=None, sample_weight_eval_set=None): - _assert_dask_installed() + _assert_dask_support() dtrain = DaskDMatrix(client=self.client, data=X, label=y, weight=sample_weights) params = self.get_xgb_params() @@ -579,7 +597,7 @@ class DaskXGBRegressor(DaskScikitLearnBase): return self def 
predict(self, data): # pylint: disable=arguments-differ - _assert_dask_installed() + _assert_dask_support() test_dmatrix = DaskDMatrix(client=self.client, data=data) pred_probs = predict(client=self.client, model=self.get_booster(), data=test_dmatrix) @@ -599,7 +617,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): sample_weights=None, eval_set=None, sample_weight_eval_set=None): - _assert_dask_installed() + _assert_dask_support() dtrain = DaskDMatrix(client=self.client, data=X, label=y, weight=sample_weights) params = self.get_xgb_params() @@ -626,7 +644,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): return self def predict(self, data): # pylint: disable=arguments-differ - _assert_dask_installed() + _assert_dask_support() test_dmatrix = DaskDMatrix(client=self.client, data=data) pred_probs = predict(client=self.client, model=self.get_booster(), data=test_dmatrix) diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index 017d28bb9..17d266df7 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -332,7 +332,7 @@ class RabitTracker(object): self.thread.start() def join(self): - while self.thread.isAlive(): + while self.thread.is_alive(): self.thread.join(100) def alive(self): diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9d753aadc..a140a195a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,5 +1,5 @@ file(GLOB_RECURSE CPU_SOURCES *.cc *.h) -list(REMOVE_ITEM CPU_SOURCES ${PROJECT_SOURCE_DIR}/src/cli_main.cc) +list(REMOVE_ITEM CPU_SOURCES ${xgboost_SOURCE_DIR}/src/cli_main.cc) #-- Object library # Object library is necessary for jvm-package, which creates its own shared @@ -9,7 +9,7 @@ if (USE_CUDA) add_library(objxgboost OBJECT ${CPU_SOURCES} ${CUDA_SOURCES} ${PLUGINS_SOURCES}) target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_CUDA=1) - target_include_directories(objxgboost PRIVATE ${PROJECT_SOURCE_DIR}/cub/) + target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/cub/) target_compile_options(objxgboost PRIVATE $<$:--expt-extended-lambda> $<$:--expt-relaxed-constexpr> @@ -43,9 +43,9 @@ endif (USE_CUDA) target_include_directories(objxgboost PRIVATE - ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/dmlc-core/include - ${PROJECT_SOURCE_DIR}/rabit/include) + ${xgboost_SOURCE_DIR}/include + ${xgboost_SOURCE_DIR}/dmlc-core/include + ${xgboost_SOURCE_DIR}/rabit/include) target_compile_options(objxgboost PRIVATE $<$,$>:/MP> diff --git a/src/common/device_helpers.cu b/src/common/device_helpers.cu new file mode 100644 index 000000000..90b58649d --- /dev/null +++ b/src/common/device_helpers.cu @@ -0,0 +1,91 @@ +/*! + * Copyright 2017-2019 XGBoost contributors + * + * \brief Utilities for CUDA. 
+ */ +#ifdef XGBOOST_USE_NCCL +#include +#endif // #ifdef XGBOOST_USE_NCCL +#include + +#include "device_helpers.cuh" + +namespace dh { + +#if __CUDACC_VER_MAJOR__ > 9 +constexpr std::size_t kUuidLength = + sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t); + +void GetCudaUUID(int world_size, int rank, int device_ord, + xgboost::common::Span<uint64_t> uuid) { + cudaDeviceProp prob; + safe_cuda(cudaGetDeviceProperties(&prob, device_ord)); + std::memcpy(uuid.data(), static_cast<void*>(&(prob.uuid)), sizeof(prob.uuid)); +} + +std::string PrintUUID(xgboost::common::Span<uint64_t> uuid) { + std::stringstream ss; + for (auto v : uuid) { + ss << std::hex << v; + } + return ss.str(); +} + +#endif // __CUDACC_VER_MAJOR__ > 9 + +void AllReducer::Init(int _device_ordinal) { +#ifdef XGBOOST_USE_NCCL + LOG(DEBUG) << "Running nccl init on: " << __CUDACC_VER_MAJOR__ << "." << __CUDACC_VER_MINOR__; + + device_ordinal = _device_ordinal; + int32_t const rank = rabit::GetRank(); + +#if __CUDACC_VER_MAJOR__ > 9 + int32_t const world = rabit::GetWorldSize(); + + std::vector<uint64_t> uuids(world * kUuidLength, 0); + auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()}; + auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength); + GetCudaUUID(world, rank, device_ordinal, s_this_uuid); + + // No allgather yet. + rabit::Allreduce(uuids.data(), uuids.size()); + + std::vector<xgboost::common::Span<uint64_t>> converted(world); + size_t j = 0; + for (size_t i = 0; i < uuids.size(); i += kUuidLength) { + converted[j] = + xgboost::common::Span<uint64_t>{uuids.data() + i, kUuidLength}; + j++; + } + + auto iter = std::unique(converted.begin(), converted.end()); + auto n_uniques = std::distance(converted.begin(), iter); + CHECK_EQ(n_uniques, world) + << "Multiple processes within a communication group running on the same CUDA " + << "device are not supported"; +#endif // __CUDACC_VER_MAJOR__ > 9 + + id = GetUniqueId(); + dh::safe_cuda(cudaSetDevice(device_ordinal)); + dh::safe_nccl(ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rank)); + safe_cuda(cudaStreamCreate(&stream)); + initialised_ = true; +#endif // XGBOOST_USE_NCCL +} + +AllReducer::~AllReducer() { +#ifdef XGBOOST_USE_NCCL + if (initialised_) { + dh::safe_cuda(cudaStreamDestroy(stream)); + ncclCommDestroy(comm); + } + if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { + LOG(CONSOLE) << "======== NCCL Statistics========"; + LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_; + LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576; + } +#endif // XGBOOST_USE_NCCL +} + +} // namespace dh diff --git a/src/common/device_helpers.cuh index 1e87c9273..42f027d10 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -7,24 +7,25 @@ #include #include #include -#include + +#include #include +#include #include -#include "xgboost/host_device_vector.h" -#include "xgboost/span.h" - -#include "common.h" - #include -#include #include #include -#include #include #include #include #include + +#include "xgboost/logging.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/span.h" + +#include "common.h" #include "timer.h" #ifdef XGBOOST_USE_NCCL @@ -205,24 +206,53 @@ __global__ void LaunchNKernel(size_t begin, size_t end, L lambda) { } template <typename L> __global__ void LaunchNKernel(int device_idx, size_t begin, size_t end, - L lambda) { + L lambda) { for (auto i : GridStrideRange(begin, end)) { lambda(i, device_idx); } } +/* \brief A wrapper around kernel launching syntax, used to guard against empty input. 
+ * + * - nvcc fails to deduce template argument when kernel is a template accepting __device__ + * function as argument. Hence functions like `LaunchN` cannot use this wrapper. + * + * - With C++ initialization list `{}` syntax, you are forced to comply with the CUDA type + * specification. + */ +class LaunchKernel { + size_t shmem_size_; + cudaStream_t stream_; + + dim3 grids_; + dim3 blocks_; + + public: + LaunchKernel(uint32_t _grids, uint32_t _blk, size_t _shmem=0, cudaStream_t _s=0) : + grids_{_grids, 1, 1}, blocks_{_blk, 1, 1}, shmem_size_{_shmem}, stream_{_s} {} + LaunchKernel(dim3 _grids, dim3 _blk, size_t _shmem=0, cudaStream_t _s=0) : + grids_{_grids}, blocks_{_blk}, shmem_size_{_shmem}, stream_{_s} {} + + template <typename K, typename... Args> + void operator()(K kernel, Args... args) { + if (XGBOOST_EXPECT(grids_.x * grids_.y * grids_.z == 0, false)) { + LOG(DEBUG) << "Skipping empty CUDA kernel."; + return; + } + kernel<<<grids_, blocks_, shmem_size_, stream_>>>(args...); // NOLINT + } +}; + template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L> inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) { if (n == 0) { return; } - safe_cuda(cudaSetDevice(device_idx)); - const int GRID_SIZE = static_cast<int>(xgboost::common::DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS)); - LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(static_cast<size_t>(0), - n, lambda); + LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>( // NOLINT + static_cast<size_t>(0), n, lambda); } // Default stream version @@ -301,6 +331,16 @@ inline detail::MemoryLogger &GlobalMemoryLogger() { return memory_logger; } +// dh::DebugSyncDevice(__FILE__, __LINE__); +inline void DebugSyncDevice(std::string file="", int32_t line = -1) { + if (file != "" && line != -1) { + auto rank = rabit::GetRank(); + LOG(DEBUG) << "R:" << rank << ": " << file << ":" << line; + } + safe_cuda(cudaDeviceSynchronize()); + safe_cuda(cudaGetLastError()); +} + namespace detail{ /** * \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose. @@ -763,7 +803,7 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory, BLOCK_THREADS, segments, num_segments, count); LbsKernel - <<>>(tmp_tile_coordinates, + <<>>(tmp_tile_coordinates, // NOLINT segments + 1, f, num_segments); } @@ -963,7 +1003,6 @@ class SaveCudaContext { * streams. Must be initialised before use. If XGBoost is compiled without NCCL * this is a dummy class that will error if used with more than one GPU. */ - class AllReducer { bool initialised_; size_t allreduce_bytes_; // Keep statistics of the number of bytes communicated @@ -986,31 +1025,9 @@ class AllReducer { * * \param device_ordinal The device ordinal. */ + void Init(int _device_ordinal); - void Init(int _device_ordinal) { -#ifdef XGBOOST_USE_NCCL - /** \brief this >monitor . init. */ - device_ordinal = _device_ordinal; - id = GetUniqueId(); - dh::safe_cuda(cudaSetDevice(device_ordinal)); - dh::safe_nccl(ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rabit::GetRank())); - safe_cuda(cudaStreamCreate(&stream)); - initialised_ = true; -#endif - } - ~AllReducer() { -#ifdef XGBOOST_USE_NCCL - if (initialised_) { - dh::safe_cuda(cudaStreamDestroy(stream)); - ncclCommDestroy(comm); - } - if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { - LOG(CONSOLE) << "======== NCCL Statistics========"; - LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_; - LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_/1048576; - } -#endif - } + ~AllReducer(); /** * \brief Allreduce. 
Use in exactly the same way as NCCL but without needing diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 83e8c117b..56bd4865e 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -293,6 +293,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) { void DenseCuts::Init (std::vector* in_sketchs, uint32_t max_num_bins) { + monitor_.Start(__func__); std::vector& sketchs = *in_sketchs; constexpr int kFactor = 8; // gather the histogram data @@ -332,6 +333,7 @@ void DenseCuts::Init CHECK_GT(cut_size, p_cuts_->cut_ptrs_.back()); p_cuts_->cut_ptrs_.push_back(cut_size); } + monitor_.Stop(__func__); } void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) { diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index a10e1f7aa..6d52e03e1 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -252,8 +252,10 @@ class GPUSketcher { }); } else if (n_cuts_cur_[icol] > 0) { // if more elements than cuts: use binary search on cumulative weights - int block = 256; - FindCutsK<<>>( + uint32_t constexpr kBlockThreads = 256; + uint32_t const kGrids = common::DivRoundUp(n_cuts_cur_[icol], kBlockThreads); + dh::LaunchKernel {kGrids, kBlockThreads} ( + FindCutsK, cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(), weights2_.data().get(), @@ -403,7 +405,8 @@ class GPUSketcher { // NOTE: This will typically support ~ 4M features - 64K*64 dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), common::DivRoundUp(num_cols_, block3.y), 1); - UnpackFeaturesK<<>>( + dh::LaunchKernel {grid3, block3} ( + UnpackFeaturesK, fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr, row_ptrs_.data().get() + batch_row_begin, diff --git a/src/common/timer.cc b/src/common/timer.cc index 41aa6da26..d4cfbfcc7 100644 --- a/src/common/timer.cc +++ b/src/common/timer.cc @@ -13,6 +13,20 @@ namespace xgboost { namespace common { +void Monitor::Start(std::string const &name) { + if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { + statistics_map[name].timer.Start(); + } +} + +void Monitor::Stop(const std::string &name) { + if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { + auto &stats = statistics_map[name]; + stats.timer.Stop(); + stats.count++; + } +} + std::vector Monitor::CollectFromOtherRanks() const { // Since other nodes might have started timers that this one haven't, so // we can't simply call all reduce. diff --git a/src/common/timer.cu b/src/common/timer.cu new file mode 100644 index 000000000..5cb72ddba --- /dev/null +++ b/src/common/timer.cu @@ -0,0 +1,38 @@ +/*! 
+ * Copyright by Contributors 2019 + */ +#if defined(XGBOOST_USE_NVTX) +#include +#endif // defined(XGBOOST_USE_NVTX) + +#include + +#include "xgboost/logging.h" +#include "device_helpers.cuh" +#include "timer.h" + +namespace xgboost { +namespace common { + +void Monitor::StartCuda(const std::string& name) { + if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { + auto &stats = statistics_map[name]; + stats.timer.Start(); +#if defined(XGBOOST_USE_NVTX) + stats.nvtx_id = nvtxRangeStartA(name.c_str()); +#endif // defined(XGBOOST_USE_NVTX) + } +} + +void Monitor::StopCuda(const std::string& name) { + if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { + auto &stats = statistics_map[name]; + stats.timer.Stop(); + stats.count++; +#if defined(XGBOOST_USE_NVTX) + nvtxRangeEnd(stats.nvtx_id); +#endif // defined(XGBOOST_USE_NVTX) + } +} +} // namespace common +} // namespace xgboost diff --git a/src/common/timer.h index a899e8798..86c3a3cef 100644 --- a/src/common/timer.h +++ b/src/common/timer.h @@ -10,10 +10,6 @@ #include #include -#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) -#include -#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) - namespace xgboost { namespace common { @@ -84,37 +80,10 @@ struct Monitor { void Print() const; void Init(std::string label) { this->label = label; } - void Start(const std::string &name) { - if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { - statistics_map[name].timer.Start(); - } - } - void Stop(const std::string &name) { - if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { - auto &stats = statistics_map[name]; - stats.timer.Stop(); - stats.count++; - } - } - void StartCuda(const std::string &name) { - if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { - auto &stats = statistics_map[name]; - stats.timer.Start(); -#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) - stats.nvtx_id = nvtxRangeStartA(name.c_str()); -#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) - } - } - void StopCuda(const std::string &name) { - if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { - auto &stats = statistics_map[name]; - stats.timer.Stop(); - stats.count++; -#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) - nvtxRangeEnd(stats.nvtx_id); -#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) - } - } + void Start(const std::string &name); + void Stop(const std::string &name); + void StartCuda(const std::string &name); + void StopCuda(const std::string &name); }; } // namespace common } // namespace xgboost diff --git a/src/common/transform.h index b62ac47cd..bcbdba22e 100644 --- a/src/common/transform.h +++ b/src/common/transform.h @@ -133,9 +133,12 @@ class Transform { size_t shard_size = range_size; Range shard_range {0, static_cast(shard_size)}; dh::safe_cuda(cudaSetDevice(device_)); - const int GRID_SIZE = + const int kGrids = static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads)); - detail::LaunchCUDAKernel<<>>( + if (kGrids == 0) { + return; + } + detail::LaunchCUDAKernel<<>>( // NOLINT _func, shard_range, UnpackHDVOnDevice(_vectors)...); } #else diff --git a/src/data/data.cc index b4f42a260..21b098162 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -320,6 +320,32 @@ void DMatrix::SaveToLocalFile(const std::string& fname) { DMatrix* DMatrix::Create(std::unique_ptr>&& source, const std::string& cache_prefix) { if (cache_prefix.length() == 0) { + // FIXME(trivialfis): Currently distcol is broken so here we check the number 
of rows. + // If we bring back column split this check will break. + bool is_distributed { rabit::IsDistributed() }; + if (is_distributed) { + auto world_size = rabit::GetWorldSize(); + auto rank = rabit::GetRank(); + std::vector ncols(world_size, 0); + ncols[rank] = source->info.num_col_; + rabit::Allreduce(ncols.data(), ncols.size()); + auto max_cols = std::max_element(ncols.cbegin(), ncols.cend()); + auto max_ind = std::distance(ncols.cbegin(), max_cols); + // FIXME(trivialfis): This is a hack, we should store a reference to global shape if possible. + if (source->info.num_col_ == 0 && source->info.num_row_ == 0) { + LOG(WARNING) << "DMatrix at rank: " << rank << " worker is empty."; + source->info.num_col_ = *max_cols; + } + + // validate the number of columns across all workers. + for (size_t i = 0; i < ncols.size(); ++i) { + auto v = ncols[i]; + CHECK(v == 0 || v == *max_cols) + << "DMatrix at rank: " << i << " worker " + << "has different number of columns than rank: " << max_ind << " worker. " + << "(" << v << " vs. " << *max_cols << ")"; + } + } return new data::SimpleDMatrix(std::move(source)); } else { #if DMLC_ENABLE_STD_THREAD diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 81ace7505..b2a081e7e 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -99,13 +99,13 @@ EllpackInfo::EllpackInfo(int device, bool is_dense, size_t row_stride, const common::HistogramCuts& hmat, - dh::BulkAllocator& ba) + dh::BulkAllocator* ba) : is_dense(is_dense), row_stride(row_stride), n_bins(hmat.Ptrs().back()) { - ba.Allocate(device, - &feature_segments, hmat.Ptrs().size(), - &gidx_fvalue_map, hmat.Values().size(), - &min_fvalue, hmat.MinValues().size()); + ba->Allocate(device, + &feature_segments, hmat.Ptrs().size(), + &gidx_fvalue_map, hmat.Values().size(), + &min_fvalue, hmat.MinValues().size()); dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values()); dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues()); dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs()); @@ -116,7 +116,7 @@ void EllpackPageImpl::InitInfo(int device, bool is_dense, size_t row_stride, const common::HistogramCuts& hmat) { - matrix.info = EllpackInfo(device, is_dense, row_stride, hmat, ba_); + matrix.info = EllpackInfo(device, is_dense, row_stride, hmat, &ba_); } // Initialize the buffer to stored compressed features. 
@@ -189,7 +189,8 @@ void EllpackPageImpl::CreateHistIndices(int device, const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), common::DivRoundUp(row_stride, block3.y), 1); - CompressBinEllpackKernel<<>>( + dh::LaunchKernel {grid3, block3} ( + CompressBinEllpackKernel, common::CompressedBufferWriter(num_symbols), gidx_buffer.data(), row_ptrs.data().get(), diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 03a6a07ed..1b38fcfa6 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -70,7 +70,7 @@ struct EllpackInfo { bool is_dense, size_t row_stride, const common::HistogramCuts& hmat, - dh::BulkAllocator& ba); + dh::BulkAllocator* ba); }; /** \brief Struct for accessing and manipulating an ellpack matrix on the diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index fcb0936b7..e2456d9a4 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -85,7 +85,7 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat, monitor_.StopCuda("Quantiles"); monitor_.StartCuda("CreateEllpackInfo"); - ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, ba_); + ellpack_info_ = EllpackInfo(device_, dmat->IsDense(), row_stride, hmat, &ba_); monitor_.StopCuda("CreateEllpackInfo"); monitor_.StartCuda("WriteEllpackPages"); diff --git a/src/data/simple_csr_source.cu b/src/data/simple_csr_source.cu index 93a9462bf..bc7ce1cd5 100644 --- a/src/data/simple_csr_source.cu +++ b/src/data/simple_csr_source.cu @@ -101,7 +101,7 @@ void CountValid(std::vector const& j_columns, uint32_t column_id, HostDeviceVector* out_offset, dh::caching_device_vector* out_d_flag, uint32_t* out_n_rows) { - int32_t constexpr kThreads = 256; + uint32_t constexpr kThreads = 256; auto const& j_column = j_columns[column_id]; auto const& column_obj = get(j_column); Columnar foreign_column = ArrayInterfaceHandler::ExtractArray(column_obj); @@ -123,8 +123,9 @@ void CountValid(std::vector const& j_columns, uint32_t column_id, common::Span s_offsets = out_offset->DeviceSpan(); - int32_t const kBlocks = common::DivRoundUp(n_rows, kThreads); - CountValidKernel<<>>( + uint32_t const kBlocks = common::DivRoundUp(n_rows, kThreads); + dh::LaunchKernel {kBlocks, kThreads} ( + CountValidKernel, foreign_column, has_missing, missing, out_d_flag->data().get(), s_offsets); @@ -135,13 +136,15 @@ template void CreateCSR(std::vector const& j_columns, uint32_t column_id, uint32_t n_rows, bool has_missing, float missing, dh::device_vector* tmp_offset, common::Span s_data) { - int32_t constexpr kThreads = 256; + uint32_t constexpr kThreads = 256; auto const& j_column = j_columns[column_id]; auto const& column_obj = get(j_column); Columnar foreign_column = ArrayInterfaceHandler::ExtractArray(column_obj); - int32_t kBlocks = common::DivRoundUp(n_rows, kThreads); - CreateCSRKernel<<>>(foreign_column, column_id, has_missing, missing, - dh::ToSpan(*tmp_offset), s_data); + uint32_t kBlocks = common::DivRoundUp(n_rows, kThreads); + dh::LaunchKernel {kBlocks, kThreads} ( + CreateCSRKernel, + foreign_column, column_id, has_missing, missing, + dh::ToSpan(*tmp_offset), s_data); } void SimpleCSRSource::FromDeviceColumnar(std::vector const& columns, diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 2de7bc72f..7fcc7c679 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -246,6 +246,14 @@ class GBTree : public GradientBooster { std::unique_ptr const& GetPredictor(HostDeviceVector const* out_pred = nullptr, DMatrix* f_dmat = nullptr) const { 
CHECK(configured_); + auto on_device = f_dmat && (*(f_dmat->GetBatches().begin())).data.DeviceCanRead(); +#if defined(XGBOOST_USE_CUDA) + // Use GPU Predictor if data is already on device. + if (!specified_predictor_ && on_device) { + CHECK(gpu_predictor_); + return gpu_predictor_; + } +#endif // defined(XGBOOST_USE_CUDA) // GPU_Hist by default has prediction cache calculated from quantile values, so GPU // Predictor is not used for training dataset. But when XGBoost performs continued // training with an existing model, the prediction cache is not available and number @@ -256,7 +264,7 @@ (model_.param.num_trees != 0) && // FIXME(trivialfis): Implement a better method for testing whether data is on // device after DMatrix refactoring is done. - (f_dmat && !((*(f_dmat->GetBatches().begin())).data.DeviceCanRead()))) { + !on_device) { return cpu_predictor_; } if (tparam_.predictor == "cpu_predictor") { diff --git a/src/learner.cc index 97467269d..21c00046b 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -630,7 +630,7 @@ class LearnerImpl : public Learner { CHECK_LE(num_col, static_cast(std::numeric_limits::max())) << "Unfortunately, XGBoost does not support data matrices with " << std::numeric_limits::max() << " features or greater"; - num_feature = std::max(num_feature, static_cast(num_col)); + num_feature = std::max(num_feature, static_cast(num_col)); } // run allreduce on num_feature to find the maximum value rabit::Allreduce(&num_feature, 1, nullptr, nullptr, "num_feature"); diff --git a/src/metric/elementwise_metric.cu index 523c3d72e..4d26ceba2 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -3,6 +3,8 @@ * \file elementwise_metric.cc * \brief evaluation metrics for elementwise binary or regression. * \author Kailong Chen, Tianqi Chen + * + * Expressions like `wsum == 0 ? esum : esum / wsum` are used to handle empty datasets. */ #include #include @@ -142,7 +144,7 @@ struct EvalRowRMSE { return diff * diff; } static bst_float GetFinal(bst_float esum, bst_float wsum) { - return std::sqrt(esum / wsum); + return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum); } }; @@ -150,12 +152,13 @@ struct EvalRowRMSLE { char const* Name() const { return "rmsle"; } + XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const { bst_float diff = std::log1p(label) - std::log1p(pred); return diff * diff; } static bst_float GetFinal(bst_float esum, bst_float wsum) { - return std::sqrt(esum / wsum); + return wsum == 0 ? std::sqrt(esum) : std::sqrt(esum / wsum); } }; @@ -168,7 +171,7 @@ struct EvalRowMAE { return std::abs(label - pred); } static bst_float GetFinal(bst_float esum, bst_float wsum) { - return esum / wsum; + return wsum == 0 ? esum : esum / wsum; } }; @@ -190,7 +193,7 @@ struct EvalRowLogLoss { } static bst_float GetFinal(bst_float esum, bst_float wsum) { - return esum / wsum; + return wsum == 0 ? esum : esum / wsum; } }; @@ -225,7 +228,7 @@ struct EvalError { } static bst_float GetFinal(bst_float esum, bst_float wsum) { - return esum / wsum; + return wsum == 0 ? esum : esum / wsum; } private: @@ -245,7 +248,7 @@ struct EvalPoissonNegLogLik { } static bst_float GetFinal(bst_float esum, bst_float wsum) { - return esum / wsum; + return wsum == 0 ? 
esum : esum / wsum; } }; @@ -278,7 +281,7 @@ struct EvalGammaNLogLik { return -((y * theta - b) / a + c); } static bst_float GetFinal(bst_float esum, bst_float wsum) { - return esum / wsum; + return wsum == 0 ? esum : esum / wsum; } }; @@ -304,7 +307,7 @@ struct EvalTweedieNLogLik { return -a + b; } static bst_float GetFinal(bst_float esum, bst_float wsum) { - return esum / wsum; + return wsum == 0 ? esum : esum / wsum; } protected: @@ -323,7 +326,9 @@ struct EvalEWiseBase : public Metric { bst_float Eval(const HostDeviceVector& preds, const MetaInfo& info, bool distributed) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + if (info.labels_.Size() == 0) { + LOG(WARNING) << "label set is empty"; + } CHECK_EQ(preds.Size(), info.labels_.Size()) << "label and prediction size not match, " << "hint: use merror or mlogloss for multi-class classification"; @@ -333,6 +338,7 @@ struct EvalEWiseBase : public Metric { reducer_.Reduce(*tparam_, device, info.weights_, info.labels_, preds); double dat[2] { result.Residue(), result.Weights() }; + if (distributed) { rabit::Allreduce(dat, 2); } diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index f04471d1d..0be272230 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -54,7 +54,9 @@ class RegLossObj : public ObjFunction { const MetaInfo &info, int iter, HostDeviceVector* out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + if (info.labels_.Size() == 0U) { + LOG(WARNING) << "Label set is empty."; + } CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided" << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size(); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index a20430b25..4a1dccb8b 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -60,6 +60,9 @@ class CPUPredictor : public Predictor { constexpr int kUnroll = 8; const auto nsize = static_cast(batch.Size()); const bst_omp_uint rest = nsize % kUnroll; + // Pull to host before entering omp block, as this is not thread safe. 
+ batch.data.HostVector(); + batch.offset.HostVector(); #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) { const int tid = omp_get_thread_num(); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 82b922224..e80ffe457 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -225,12 +225,12 @@ class GPUPredictor : public xgboost::Predictor { HostDeviceVector* predictions, size_t batch_offset) { dh::safe_cuda(cudaSetDevice(device_)); - const int BLOCK_THREADS = 128; + const uint32_t BLOCK_THREADS = 128; size_t num_rows = batch.Size(); - const int GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); + auto GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); - int shared_memory_bytes = static_cast - (sizeof(float) * num_features * BLOCK_THREADS); + auto shared_memory_bytes = + static_cast(sizeof(float) * num_features * BLOCK_THREADS); bool use_shared = true; if (shared_memory_bytes > max_shared_memory_bytes_) { shared_memory_bytes = 0; @@ -238,11 +238,12 @@ class GPUPredictor : public xgboost::Predictor { } size_t entry_start = 0; - PredictKernel<<>> - (dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset), - dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(), - batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows, - entry_start, use_shared, this->num_group_); + dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( + PredictKernel, + dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset), + dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(), + batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows, + entry_start, use_shared, this->num_group_); } void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { diff --git a/src/tree/constraints.cu b/src/tree/constraints.cu index 20f800e35..472c2f0ec 100644 --- a/src/tree/constraints.cu +++ b/src/tree/constraints.cu @@ -165,10 +165,11 @@ __global__ void ClearBuffersKernel( void FeatureInteractionConstraint::ClearBuffers() { CHECK_EQ(output_buffer_bits_.Size(), input_buffer_bits_.Size()); CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size()); - int constexpr kBlockThreads = 256; - const int n_grids = static_cast( + uint32_t constexpr kBlockThreads = 256; + auto const n_grids = static_cast( common::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads)); - ClearBuffersKernel<<>>( + dh::LaunchKernel {n_grids, kBlockThreads} ( + ClearBuffersKernel, output_buffer_bits_, input_buffer_bits_); } @@ -222,12 +223,14 @@ common::Span FeatureInteractionConstraint::Query( LBitField64 node_constraints = s_node_constraints_[nid]; CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size()); - int constexpr kBlockThreads = 256; - const int n_grids = static_cast( + uint32_t constexpr kBlockThreads = 256; + auto n_grids = static_cast( common::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads)); - SetInputBufferKernel<<>>(feature_list, input_buffer_bits_); - - QueryFeatureListKernel<<>>( + dh::LaunchKernel {n_grids, kBlockThreads} ( + SetInputBufferKernel, + feature_list, input_buffer_bits_); + dh::LaunchKernel {n_grids, kBlockThreads} ( + QueryFeatureListKernel, node_constraints, input_buffer_bits_, output_buffer_bits_); thrust::counting_iterator begin(0); @@ -327,20 +330,20 @@ void FeatureInteractionConstraint::Split( dim3 const block3(16, 
64, 1); dim3 const grid3(common::DivRoundUp(n_sets_, 16), common::DivRoundUp(s_fconstraints_.size(), 64)); - RestoreFeatureListFromSetsKernel<<>> - (feature_buffer_, - feature_id, - s_fconstraints_, - s_fconstraints_ptr_, - s_sets_, - s_sets_ptr_); + dh::LaunchKernel {grid3, block3} ( + RestoreFeatureListFromSetsKernel, + feature_buffer_, feature_id, + s_fconstraints_, s_fconstraints_ptr_, + s_sets_, s_sets_ptr_); - int constexpr kBlockThreads = 256; - const int n_grids = static_cast(common::DivRoundUp(node.Size(), kBlockThreads)); - InteractionConstraintSplitKernel<<>> - (feature_buffer_, - feature_id, - node, left, right); + uint32_t constexpr kBlockThreads = 256; + auto n_grids = static_cast(common::DivRoundUp(node.Size(), kBlockThreads)); + + dh::LaunchKernel {n_grids, kBlockThreads} ( + InteractionConstraintSplitKernel, + feature_buffer_, + feature_id, + node, left, right); } } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index d239986a4..83ccb3b1d 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -603,12 +603,12 @@ struct GPUHistMakerDevice { } // One block for each feature - int constexpr kBlockThreads = 256; - EvaluateSplitKernel - <<>>( - hist.GetNodeHistogram(nidx), d_feature_set, node, page->matrix, - gpu_param, d_split_candidates, node_value_constraints[nidx], - monotone_constraints); + uint32_t constexpr kBlockThreads = 256; + dh::LaunchKernel {uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]} ( + EvaluateSplitKernel, + hist.GetNodeHistogram(nidx), d_feature_set, node, page->matrix, + gpu_param, d_split_candidates, node_value_constraints[nidx], + monotone_constraints); // Reduce over features to find best feature auto d_cub_memory = @@ -638,14 +638,12 @@ struct GPUHistMakerDevice { use_shared_memory_histograms ? 
sizeof(GradientSumT) * page->matrix.BinCount() : 0; - const int items_per_thread = 8; - const int block_threads = 256; - const int grid_size = static_cast( + uint32_t items_per_thread = 8; + uint32_t block_threads = 256; + auto grid_size = static_cast( common::DivRoundUp(n_elements, items_per_thread * block_threads)); - if (grid_size <= 0) { - return; - } - SharedMemHistKernel<<>>( + dh::LaunchKernel {grid_size, block_threads, smem_size} ( + SharedMemHistKernel, page->matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements, use_shared_memory_histograms); } @@ -886,6 +884,7 @@ struct GPUHistMakerDevice { monitor.StartCuda("InitRoot"); this->InitRoot(p_tree, gpair_all, reducer, p_fmat->Info().num_col_); monitor.StopCuda("InitRoot"); + auto timestamp = qexpand->size(); auto num_leaves = 1; @@ -895,7 +894,6 @@ struct GPUHistMakerDevice { if (!candidate.IsValid(param, num_leaves)) { continue; } - this->ApplySplit(candidate, p_tree); num_leaves++; @@ -996,18 +994,22 @@ class GPUHistMakerSpecialised { try { for (xgboost::RegTree* tree : trees) { this->UpdateTree(gpair, dmat, tree); + + if (hist_maker_param_.debug_synchronize) { + this->CheckTreesSynchronized(tree); + } } dh::safe_cuda(cudaGetLastError()); } catch (const std::exception& e) { LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl; } + param_.learning_rate = lr; monitor_.StopCuda("Update"); } void InitDataOnce(DMatrix* dmat) { info_ = &dmat->Info(); - reducer_.Init({device_}); // Synchronise the column sampling seed @@ -1048,20 +1050,18 @@ class GPUHistMakerSpecialised { } // Only call this method for testing - void CheckTreesSynchronized(const std::vector& local_trees) const { + void CheckTreesSynchronized(RegTree* local_tree) const { std::string s_model; common::MemoryBufferStream fs(&s_model); int rank = rabit::GetRank(); if (rank == 0) { - local_trees.front().SaveModel(&fs); + local_tree->SaveModel(&fs); } fs.Seek(0); rabit::Broadcast(&s_model, 0); - RegTree reference_tree{}; + RegTree reference_tree {}; // rank 0 tree reference_tree.LoadModel(&fs); - for (const auto& tree : local_trees) { - CHECK(tree == reference_tree); - } + CHECK(*local_tree == reference_tree); } void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, diff --git a/tests/ci_build/Dockerfile.cudf b/tests/ci_build/Dockerfile.cudf index 02d6c9325..bcea72eff 100644 --- a/tests/ci_build/Dockerfile.cudf +++ b/tests/ci_build/Dockerfile.cudf @@ -18,7 +18,7 @@ ENV PATH=/opt/python/bin:$PATH # Create new Conda environment with cuDF and dask RUN \ conda create -n cudf_test -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda \ - cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask + cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask dask-cuda # Install other Python packages RUN \ diff --git a/tests/ci_build/Dockerfile.gpu b/tests/ci_build/Dockerfile.gpu index 533110019..dc92906e9 100644 --- a/tests/ci_build/Dockerfile.gpu +++ b/tests/ci_build/Dockerfile.gpu @@ -17,7 +17,8 @@ ENV PATH=/opt/python/bin:$PATH # Install Python packages RUN \ pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \ - pip install "dask[complete]" + pip install "dask[complete]" && \ + conda install -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda dask-cuda ENV GOSU_VERSION 1.10 diff --git a/tests/ci_build/Dockerfile.gpu_build b/tests/ci_build/Dockerfile.gpu_build index ca5472f6b..052408f61 100644 --- a/tests/ci_build/Dockerfile.gpu_build +++ b/tests/ci_build/Dockerfile.gpu_build @@ -21,18 +21,12 @@ RUN \ # 
NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html) RUN \ export CUDA_SHORT=`echo $CUDA_VERSION | egrep -o '[0-9]+\.[0-9]'` && \ - if [ "${CUDA_SHORT}" != "10.0" ] && [ "${CUDA_SHORT}" != "10.1" ]; then \ - wget https://developer.download.nvidia.com/compute/redist/nccl/v2.2/nccl_2.2.13-1%2Bcuda${CUDA_SHORT}_x86_64.txz && \ - tar xf "nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64.txz" && \ - cp nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64/include/nccl.h /usr/include && \ - cp nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64/lib/* /usr/lib && \ - rm -f nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64.txz && \ - rm -r nccl_2.2.13-1+cuda${CUDA_SHORT}_x86_64; else \ + export NCCL_VERSION=2.4.8-1 && \ wget https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \ rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \ yum -y update && \ - yum install -y libnccl-2.4.2-1+cuda${CUDA_SHORT} libnccl-devel-2.4.2-1+cuda${CUDA_SHORT} libnccl-static-2.4.2-1+cuda${CUDA_SHORT} && \ - rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm; fi + yum install -y libnccl-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-devel-${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-static-${NCCL_VERSION}+cuda${CUDA_SHORT} && \ + rm -f nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm; ENV PATH=/opt/python/bin:$PATH ENV CC=/opt/rh/devtoolset-4/root/usr/bin/gcc diff --git a/tests/ci_build/test_python.sh b/tests/ci_build/test_python.sh index dd1722163..db99db9d2 100755 --- a/tests/ci_build/test_python.sh +++ b/tests/ci_build/test_python.sh @@ -33,11 +33,13 @@ case "$suite" in pytest -v -s --fulltrace -m "(not slow) and mgpu" tests/python-gpu cd tests/distributed ./runtests-gpu.sh + cd - + pytest -v -s --fulltrace -m "mgpu" tests/python-gpu/test_gpu_with_dask.py ;; cudf) source activate cudf_test - python -m pytest -v -s --fulltrace tests/python-gpu/test_from_columnar.py tests/python-gpu/test_gpu_with_dask.py + pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py ;; cpu) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 319abe004..09a40deee 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -19,7 +19,7 @@ if (USE_CUDA) # OpenMP is mandatory for CUDA find_package(OpenMP REQUIRED) target_include_directories(testxgboost PRIVATE - ${PROJECT_SOURCE_DIR}/cub/) + ${xgboost_SOURCE_DIR}/cub/) target_compile_options(testxgboost PRIVATE $<$:--expt-extended-lambda> $<$:--expt-relaxed-constexpr> @@ -48,9 +48,9 @@ endif (USE_CUDA) target_include_directories(testxgboost PRIVATE ${GTEST_INCLUDE_DIRS} - ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/dmlc-core/include - ${PROJECT_SOURCE_DIR}/rabit/include) + ${xgboost_SOURCE_DIR}/include + ${xgboost_SOURCE_DIR}/dmlc-core/include + ${xgboost_SOURCE_DIR}/rabit/include) set_target_properties( testxgboost PROPERTIES CXX_STANDARD 11 @@ -67,7 +67,7 @@ target_compile_definitions(testxgboost PRIVATE ${XGBOOST_DEFINITIONS}) if (USE_OPENMP) target_compile_options(testxgboost PRIVATE $<$:${OpenMP_CXX_FLAGS}>) endif (USE_OPENMP) -set_output_directory(testxgboost ${PROJECT_BINARY_DIR}) +set_output_directory(testxgboost ${xgboost_BINARY_DIR}) # This grouping organises source files nicely in visual studio auto_source_group("${TEST_SOURCES}") diff --git a/tests/distributed/distributed_gpu.py b/tests/distributed/distributed_gpu.py index 4c47f54a3..f30e39b1b 100644 --- a/tests/distributed/distributed_gpu.py +++ b/tests/distributed/distributed_gpu.py @@ -2,6 
+2,7 @@ import sys import time import xgboost as xgb +import os def run_test(name, params_fun): @@ -48,6 +49,9 @@ def run_test(name, params_fun): xgb.rabit.finalize() + if os.path.exists(model_name): + os.remove(model_name) + base_params = { 'tree_method': 'gpu_hist', @@ -81,7 +85,5 @@ def wrap_rf(params_fun): params_rf_1x4 = wrap_rf(params_basic_1x4) - - test_name = sys.argv[1] run_test(test_name, globals()['params_%s' % test_name]) diff --git a/tests/distributed/runtests-gpu.sh b/tests/distributed/runtests-gpu.sh index e942efd25..cc2d23cec 100755 --- a/tests/distributed/runtests-gpu.sh +++ b/tests/distributed/runtests-gpu.sh @@ -6,7 +6,7 @@ export DMLC_SUBMIT_CLUSTER=local submit="timeout 30 python ../../dmlc-core/tracker/dmlc-submit" echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n" -$submit --num-workers=4 python distributed_gpu.py basic_1x4 || exit 1 +$submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py basic_1x4 || exit 1 echo -e "\n ====== 2. RF distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n" -$submit --num-workers=4 python distributed_gpu.py rf_1x4 || exit 1 +$submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py rf_1x4 || exit 1 diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 000000000..80c6579a8 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + mgpu: Mark a test that requires multiple GPUs to run. \ No newline at end of file diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index fc21293ea..91b9e50d2 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -2,6 +2,7 @@ import numpy as np import sys import unittest import pytest +import xgboost sys.path.append("tests/python") from regression_test_utilities import run_suite, parameter_combinations, \ @@ -21,7 +22,8 @@ datasets = ["Boston", "Cancer", "Digits", "Sparse regression", class TestGPU(unittest.TestCase): def test_gpu_hist(self): - test_param = parameter_combinations({'gpu_id': [0], 'max_depth': [2, 8], + test_param = parameter_combinations({'gpu_id': [0], + 'max_depth': [2, 8], 'max_leaves': [255, 4], 'max_bin': [2, 256], 'grow_policy': ['lossguide']}) @@ -36,6 +38,31 @@ class TestGPU(unittest.TestCase): cpu_results = run_suite(param, select_datasets=datasets) assert_gpu_results(cpu_results, gpu_results) + def test_with_empty_dmatrix(self): + # FIXME(trivialfis): This should be done with all updaters + kRows = 0 + kCols = 100 + + X = np.empty((kRows, kCols)) + y = np.empty((kRows)) + + dtrain = xgboost.DMatrix(X, y) + + bst = xgboost.train({'verbosity': 2, + 'tree_method': 'gpu_hist', + 'gpu_id': 0}, + dtrain, + verbose_eval=True, + num_boost_round=6, + evals=[(dtrain, 'Train')]) + + kRows = 100 + X = np.random.randn(kRows, kCols) + + dtest = xgboost.DMatrix(X) + predictions = bst.predict(dtest) + np.testing.assert_allclose(predictions, 0.5, 1e-6) + @pytest.mark.mgpu def test_specified_gpu_id_gpu_update(self): variable_param = {'gpu_id': [1], diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index c6a3f8407..8bdeb8299 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -1,45 +1,94 @@ import sys import pytest +import numpy as np +import unittest if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows", allow_module_level=True) try: - from distributed.utils_test 
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index c6a3f8407..8bdeb8299 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -1,45 +1,94 @@
 import sys
 import pytest
+import numpy as np
+import unittest
 
 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
 
 try:
-    from distributed.utils_test import client, loop, cluster_fixture
     import dask.dataframe as dd
     from xgboost import dask as dxgb
+    from dask_cuda import LocalCUDACluster
+    from dask.distributed import Client
     import cudf
 except ImportError:
-    client = None
-    loop = None
-    cluster_fixture = None
     pass
 
 sys.path.append("tests/python")
-from test_with_dask import generate_array
-import testing as tm
+from test_with_dask import generate_array  # noqa
+import testing as tm  # noqa
 
 
-@pytest.mark.skipif(**tm.no_dask())
-@pytest.mark.skipif(**tm.no_cudf())
-@pytest.mark.skipif(**tm.no_dask_cudf())
-def test_dask_dataframe(client):
-    X, y = generate_array()
+class TestDistributedGPU(unittest.TestCase):
+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.skipif(**tm.no_cudf())
+    @pytest.mark.skipif(**tm.no_dask_cudf())
+    @pytest.mark.skipif(**tm.no_dask_cuda())
+    def test_dask_dataframe(self):
+        with LocalCUDACluster() as cluster:
+            with Client(cluster) as client:
+                X, y = generate_array()
 
-    X = dd.from_dask_array(X)
-    y = dd.from_dask_array(y)
+                X = dd.from_dask_array(X)
+                y = dd.from_dask_array(y)
 
-    X = X.map_partitions(cudf.from_pandas)
-    y = y.map_partitions(cudf.from_pandas)
+                X = X.map_partitions(cudf.from_pandas)
+                y = y.map_partitions(cudf.from_pandas)
 
-    dtrain = dxgb.DaskDMatrix(client, X, y)
-    out = dxgb.train(client, {'tree_method': 'gpu_hist'},
-                     dtrain=dtrain,
-                     evals=[(dtrain, 'X')],
-                     num_boost_round=2)
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+                out = dxgb.train(client, {'tree_method': 'gpu_hist'},
+                                 dtrain=dtrain,
+                                 evals=[(dtrain, 'X')],
+                                 num_boost_round=2)
 
-    assert isinstance(out['booster'], dxgb.Booster)
-    assert len(out['history']['X']['rmse']) == 2
+                assert isinstance(out['booster'], dxgb.Booster)
+                assert len(out['history']['X']['rmse']) == 2
 
-    predictions = dxgb.predict(out, dtrain)
-    predictions = predictions.compute()
+                # FIXME(trivialfis): Re-enable this after #5003 is fixed
+                # predictions = dxgb.predict(client, out, dtrain).compute()
+                # assert isinstance(predictions, np.ndarray)
+
+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.skipif(**tm.no_dask_cuda())
+    @pytest.mark.mgpu
+    def test_empty_dmatrix(self):
+
+        def _check_outputs(out, predictions):
+            assert isinstance(out['booster'], dxgb.Booster)
+            assert len(out['history']['validation']['rmse']) == 2
+            assert isinstance(predictions, np.ndarray)
+            assert predictions.shape[0] == 1
+
+        parameters = {'tree_method': 'gpu_hist', 'verbosity': 3,
+                      'debug_synchronize': True}
+
+        with LocalCUDACluster() as cluster:
+            with Client(cluster) as client:
+                kRows, kCols = 1, 97
+                X = dd.from_array(np.random.randn(kRows, kCols))
+                y = dd.from_array(np.random.rand(kRows))
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+
+                out = dxgb.train(client, parameters,
+                                 dtrain=dtrain,
+                                 evals=[(dtrain, 'validation')],
+                                 num_boost_round=2)
+                predictions = dxgb.predict(client=client, model=out,
+                                           data=dtrain).compute()
+                _check_outputs(out, predictions)
+
+                # train has more rows than evals
+                valid = dtrain
+                kRows += 1
+                X = dd.from_array(np.random.randn(kRows, kCols))
+                y = dd.from_array(np.random.rand(kRows))
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+
+                out = dxgb.train(client, parameters,
+                                 dtrain=dtrain,
+                                 evals=[(valid, 'validation')],
+                                 num_boost_round=2)
+                predictions = dxgb.predict(client=client, model=out,
+                                           data=valid).compute()
+                _check_outputs(out, predictions)
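For reference, test_empty_dmatrix mirrors the intended user workflow: with the dask.py changes in this patch every worker joins training even when its data partition is empty. A minimal end-to-end sketch, assuming dask, dask_cuda, and at least one visible GPU; the shapes and chunk sizes are arbitrary:

    import dask.array as da
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster
    from xgboost import dask as dxgb

    with LocalCUDACluster() as cluster:
        with Client(cluster) as client:
            # Some chunks may land on workers holding no rows at all.
            X = da.random.random((1000, 10), chunks=(100, 10))
            y = da.random.random(1000, chunks=100)
            dtrain = dxgb.DaskDMatrix(client, X, y)
            out = dxgb.train(client, {'tree_method': 'gpu_hist'},
                             dtrain=dtrain, num_boost_round=4)
            # `out` carries the trained booster and evaluation history.
            predictions = dxgb.predict(client=client, model=out,
                                       data=dtrain).compute()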
diff --git a/tests/python/regression_test_utilities.py b/tests/python/regression_test_utilities.py
index 2cf31c691..46c240b2f 100644
--- a/tests/python/regression_test_utilities.py
+++ b/tests/python/regression_test_utilities.py
@@ -67,7 +67,8 @@ def get_weights_regression(min_weight, max_weight):
     n = 10000
     sparsity = 0.25
     X, y = datasets.make_regression(n, random_state=rng)
-    X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X])
+    X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x
+                   for x in x_row] for x_row in X])
     w = np.array([rng.uniform(min_weight, max_weight) for i in range(n)])
     return X, y, w
 
diff --git a/tests/python/testing.py b/tests/python/testing.py
index 6073d381b..99747e04a 100644
--- a/tests/python/testing.py
+++ b/tests/python/testing.py
@@ -34,6 +34,15 @@ def no_matplotlib():
             'reason': reason}
 
 
+def no_dask_cuda():
+    reason = 'dask_cuda is not installed.'
+    try:
+        import dask_cuda as _  # noqa
+        return {'condition': False, 'reason': reason}
+    except ImportError:
+        return {'condition': True, 'reason': reason}
+
+
 def no_cudf():
     return {'condition': not CUDF_INSTALLED,
             'reason': 'CUDF is not installed'}
diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh
index 20216b4cd..6f6fbb301 100755
--- a/tests/travis/run_test.sh
+++ b/tests/travis/run_test.sh
@@ -34,6 +34,13 @@ fi
 
 if [ ${TASK} == "cmake_test" ]; then
     set -e
+
+    if grep -n -R '<<<.*>>>\(.*\)' src include | grep --invert "NOLINT"; then
+        echo 'Do not use raw CUDA execution configuration syntax with <<<...>>>.' \
+             'try `dh::LaunchKernel`'
+        exit -1
+    fi
+
     # Build/test
     rm -rf build
     mkdir build && cd build