From c1104f7d0afdaedd710014dfdc9f54e89512eb3a Mon Sep 17 00:00:00 2001
From: PSEUDOTENSOR / Jonathan McKinney
Date: Wed, 9 Aug 2017 21:07:07 -0700
Subject: [PATCH] [GPU-Plugin] Add throw of asserts and added compute compatibility error check. (#2565)

* [GPU-Plugin] Added compute compatibility error check, added verbose timing
---
 plugin/updater_gpu/src/device_helpers.cuh    | 10 ++-
 plugin/updater_gpu/src/gpu_hist_builder.cu   | 89 ++++++++++++++++++--
 plugin/updater_gpu/src/gpu_hist_builder.cuh  |  7 ++
 plugin/updater_gpu/test/python/test.py       | 13 +++
 plugin/updater_gpu/test/python/test_large.py | 43 +++++++---
 5 files changed, 142 insertions(+), 20 deletions(-)

diff --git a/plugin/updater_gpu/src/device_helpers.cuh b/plugin/updater_gpu/src/device_helpers.cuh
index f734e9999..eedfdea48 100644
--- a/plugin/updater_gpu/src/device_helpers.cuh
+++ b/plugin/updater_gpu/src/device_helpers.cuh
@@ -63,15 +63,23 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
 #define gpuErrchk(ans) \
   { gpuAssert((ans), __FILE__, __LINE__); }
+
 inline void gpuAssert(cudaError_t code, const char *file, int line,
                       bool abort = true) {
   if (code != cudaSuccess) {
     fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
             line);
-    if (abort) exit(code);
+    if (abort) {
+      std::stringstream ss;
+      ss << file << "(" << line << ")";
+      std::string file_and_line;
+      ss >> file_and_line;
+      throw thrust::system_error(code, thrust::cuda_category(), file_and_line);
+    }
   }
 }
+
 inline int n_visible_devices() {
   int n_visgpus = 0;
diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cu b/plugin/updater_gpu/src/gpu_hist_builder.cu
index 47e71c030..5a89a15b8 100644
--- a/plugin/updater_gpu/src/gpu_hist_builder.cu
+++ b/plugin/updater_gpu/src/gpu_hist_builder.cu
@@ -6,6 +6,8 @@
 #include 
 #include 
 #include 
+#include 
+#include 
 #include 
 #include 
 #include 
@@ -130,7 +132,7 @@ void GPUHistBuilder::Init(const TrainParam& param) {
   if (!param.silent) {
     size_t free_memory = dh::available_memory(device_idx);
     const int mb_size = 1048576;
-    LOG(CONSOLE) << "Device: [" << device_idx << "] "
+    LOG(CONSOLE) << "[GPU Plug-in] Device: [" << device_idx << "] "
                  << dh::device_name(device_idx) << " with "
                  << free_memory / mb_size << " MB available device memory.";
   }
@@ -139,12 +141,19 @@ void GPUHistBuilder::Init(const TrainParam& param) {
 void GPUHistBuilder::InitData(const std::vector& gpair,
                               DMatrix& fmat,  // NOLINT
                               const RegTree& tree) {
+  dh::Timer time1;
   // set member num_rows and n_devices for rest of GPUHistBuilder members
   info = &fmat.info();
   num_rows = info->num_row;
   n_devices = dh::n_devices(param.n_gpus, num_rows);
 
   if (!initialised) {
+    // reset static timers used across iterations
+    cpu_init_time = 0;
+    gpu_init_time = 0;
+    cpu_time.reset();
+    gpu_time = 0;
+
     // set dList member
     dList.resize(n_devices);
     for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
@@ -176,7 +185,13 @@ void GPUHistBuilder::InitData(const std::vector& gpair,
       dh::safe_cuda(cudaGetDeviceProperties(&prop, cudaDev));
       // printf("# Rank %2d uses device %2d [0x%02x] %s\n", rank, cudaDev,
       //        prop.pciBusID, prop.name);
-      fflush(stdout);
+      // cudaDriverGetVersion(&driverVersion);
+      // cudaRuntimeGetVersion(&runtimeVersion);
+      std::ostringstream oss;
+      oss << "CUDA Capability Major/Minor version number: " << prop.major
+          << "." << prop.minor << " is insufficient. Need >=3.5.";
+      int failed = prop.major < 3 || (prop.major == 3 && prop.minor < 5);
+      CHECK(failed == 0) << oss.str();
     }
 
     // local find_split group of comms for each case of reduced number of GPUs
@@ -199,9 +214,40 @@ void GPUHistBuilder::InitData(const std::vector& gpair,
         "block. Try setting 'tree_method' "
         "parameter to 'exact'";
   is_dense = info->num_nonzero == info->num_col * info->num_row;
+  dh::Timer time0;
   hmat_.Init(&fmat, param.max_bin);
+  cpu_init_time += time0.elapsedSeconds();
+  if (param.debug_verbose) {  // Only done once for each training session
+    LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init "
+                 << time0.elapsedSeconds() << " sec";
+    fflush(stdout);
+  }
+  time0.reset();
+
   gmat_.cut = &hmat_;
+  cpu_init_time += time0.elapsedSeconds();
+  if (param.debug_verbose) {  // Only done once for each training session
+    LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.cut "
+                 << time0.elapsedSeconds() << " sec";
+    fflush(stdout);
+  }
+  time0.reset();
+
   gmat_.Init(&fmat);
+  cpu_init_time += time0.elapsedSeconds();
+  if (param.debug_verbose) {  // Only done once for each training session
+    LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.Init() "
+                 << time0.elapsedSeconds() << " sec";
+    fflush(stdout);
+  }
+  time0.reset();
+
+  if (param.debug_verbose) {  // Only done once for each training session
+    LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init, gmat_.cut, gmat_.Init "
+                 << cpu_init_time << " sec";
+    fflush(stdout);
+  }
+
   int n_bins = hmat_.row_ptr.back();
   int n_features = hmat_.row_ptr.size() - 1;
@@ -324,10 +370,8 @@ void GPUHistBuilder::InitData(const std::vector& gpair,
 
   if (!param.silent) {
     const int mb_size = 1048576;
-    LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << " MB";
+    LOG(CONSOLE) << "[GPU Plug-in] Allocated " << ba.size() / mb_size << " MB";
   }
-
-  initialised = true;
 }
 
 // copy or init to do every iteration
@@ -355,7 +399,20 @@ void GPUHistBuilder::InitData(const std::vector& gpair,
 
   dh::synchronize_n_devices(n_devices, dList);
 
+  if (!initialised) {
+    gpu_init_time = time1.elapsedSeconds() - cpu_init_time;
+    gpu_time = -cpu_init_time;
+    if (param.debug_verbose) {  // Only done once for each training session
+      LOG(CONSOLE) << "[GPU Plug-in] Time for GPU operations during First Call to InitData() "
+                   << gpu_init_time << " sec";
+      fflush(stdout);
+    }
+  }
+
   p_last_fmat_ = &fmat;
+
+  initialised = true;
 }
 
 void GPUHistBuilder::BuildHist(int depth) {
@@ -623,7 +680,7 @@ __global__ void find_split_kernel(
 #define MIN_BLOCK_THREADS 32
 #define CHUNK_BLOCK_THREADS 32
 // MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due
-// to CUDA compatibility 35 and above requirement
+// to CUDA capability 35 and above requirement
 // for Maximum number of threads per block
 #define MAX_BLOCK_THREADS 1024
@@ -1082,6 +1139,8 @@ bool GPUHistBuilder::UpdatePredictionCache(
 
 void GPUHistBuilder::Update(const std::vector& gpair, DMatrix* p_fmat,
                             RegTree* p_tree) {
+  dh::Timer time0;
+
   this->InitData(gpair, *p_fmat, *p_tree);
   this->InitFirstNode(gpair);
   this->ColSampleTree();
@@ -1097,6 +1156,24 @@ void GPUHistBuilder::Update(const std::vector& gpair,
   int master_device = dList[0];
   dh::safe_cuda(cudaSetDevice(master_device));
   dense2sparse_tree(p_tree, nodes[0].tbegin(), nodes[0].tend(), param);
+
+  gpu_time += time0.elapsedSeconds();
+
+  if (param.debug_verbose) {
+    LOG(CONSOLE) << "[GPU Plug-in] Cumulative GPU Time excluding initial time "
+                 << (gpu_time - gpu_init_time)
+                 << " sec";
+    fflush(stdout);
+  }
+
+  if (param.debug_verbose) {
+    LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time "
+                 << cpu_time.elapsedSeconds() << " sec";
+    LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time excluding initial time "
+                 << (cpu_time.elapsedSeconds() - cpu_init_time - gpu_time)
+                 << " sec";
+    fflush(stdout);
+  }
 }
 }  // namespace tree
 }  // namespace xgboost
diff --git a/plugin/updater_gpu/src/gpu_hist_builder.cuh b/plugin/updater_gpu/src/gpu_hist_builder.cuh
index f80ca2990..72f2cf946 100644
--- a/plugin/updater_gpu/src/gpu_hist_builder.cuh
+++ b/plugin/updater_gpu/src/gpu_hist_builder.cuh
@@ -122,6 +122,13 @@ class GPUHistBuilder {
   std::vector streams;
   std::vector comms;
   std::vector> find_split_comms;
+
+  double cpu_init_time;
+  double gpu_init_time;
+  dh::Timer cpu_time;
+  double gpu_time;
+
+
 };
 }  // namespace tree
 }  // namespace xgboost
diff --git a/plugin/updater_gpu/test/python/test.py b/plugin/updater_gpu/test/python/test.py
index 9823063f2..34e3ed339 100644
--- a/plugin/updater_gpu/test/python/test.py
+++ b/plugin/updater_gpu/test/python/test.py
@@ -32,6 +32,7 @@ class TestGPU(unittest.TestCase):
                     'nthread': 0,
                     'eta': 1,
                     'silent': 1,
+                    'debug_verbose': 0,
                     'objective': 'binary:logistic',
                     'eval_metric': 'auc'}
         ag_param2 = {'max_depth': 2,
@@ -39,6 +40,7 @@ class TestGPU(unittest.TestCase):
                     'nthread': 0,
                     'eta': 1,
                     'silent': 1,
+                    'debug_verbose': 0,
                     'objective': 'binary:logistic',
                     'eval_metric': 'auc'}
         ag_res = {}
@@ -63,6 +65,7 @@ class TestGPU(unittest.TestCase):
                  'nthread': 0,
                  'tree_method': 'gpu_exact',
                  'max_depth': 3,
+                 'debug_verbose': 0,
                  'eval_metric': 'auc'}
         res = {}
         xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
@@ -80,6 +83,7 @@ class TestGPU(unittest.TestCase):
                  'nthread': 0,
                  'tree_method': 'gpu_exact',
                  'max_depth': 2,
+                 'debug_verbose': 0,
                  'eval_metric': 'auc'}
         res = {}
         xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
@@ -134,6 +138,7 @@ class TestGPU(unittest.TestCase):
                     'nthread': 0,
                     'eta': 1,
                     'silent': 1,
+                    'debug_verbose': 0,
                     'objective': 'binary:logistic',
                     'eval_metric': 'auc'}
         ag_param2 = {'max_depth': max_depth,
@@ -141,6 +146,7 @@ class TestGPU(unittest.TestCase):
                     'tree_method': 'gpu_hist',
                     'eta': 1,
                     'silent': 1,
+                    'debug_verbose': 0,
                     'n_gpus': 1,
                     'objective': 'binary:logistic',
                     'max_bin': max_bin,
@@ -150,6 +156,7 @@ class TestGPU(unittest.TestCase):
                     'tree_method': 'gpu_hist',
                     'eta': 1,
                     'silent': 1,
+                    'debug_verbose': 0,
                     'n_gpus': n_gpus,
                     'objective': 'binary:logistic',
                     'max_bin': max_bin,
@@ -187,6 +194,7 @@ class TestGPU(unittest.TestCase):
                  'max_depth': max_depth,
                  'n_gpus': 1,
                  'max_bin': max_bin,
+                 'debug_verbose': 0,
                  'eval_metric': 'auc'}
         res = {}
         #eprint("digits: grow_gpu_hist updater 1 gpu");
@@ -200,6 +208,7 @@ class TestGPU(unittest.TestCase):
                  'max_depth': max_depth,
                  'n_gpus': n_gpus,
                  'max_bin': max_bin,
+                 'debug_verbose': 0,
                  'eval_metric': 'auc'}
         res2 = {}
         #eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
@@ -223,6 +232,7 @@ class TestGPU(unittest.TestCase):
                  'max_depth': max_depth,
                  'n_gpus': n_gpus,
                  'max_bin': max_bin,
+                 'debug_verbose': 0,
                  'eval_metric': 'auc'}
         res = {}
         xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
@@ -262,6 +272,7 @@ class TestGPU(unittest.TestCase):
                  'tree_method': 'gpu_hist',
                  'max_depth': max_depth,
                  'n_gpus': n_gpus,
+                 'debug_verbose': 0,
                  'eval_metric': 'auc',
                  'max_bin': max_bin}
         res = {}
@@ -280,6 +291,7 @@ class TestGPU(unittest.TestCase):
                  'colsample_bytree': 0.5,
                  'colsample_bylevel': 0.5,
                  'subsample': 0.5,
+                 'debug_verbose': 0,
                  'max_bin': max_bin}
         res = {}
         xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')],
evals_result=res) @@ -293,6 +305,7 @@ class TestGPU(unittest.TestCase): 'tree_method': 'gpu_hist', 'max_depth': 2, 'n_gpus': n_gpus, + 'debug_verbose': 0, 'eval_metric': 'auc', 'max_bin': 2} res = {} diff --git a/plugin/updater_gpu/test/python/test_large.py b/plugin/updater_gpu/test/python/test_large.py index 5878363ad..13aa930fc 100644 --- a/plugin/updater_gpu/test/python/test_large.py +++ b/plugin/updater_gpu/test/python/test_large.py @@ -1,6 +1,7 @@ from __future__ import print_function #pylint: skip-file import sys +import time sys.path.append("../../tests/python") import xgboost as xgb import testing as tm @@ -33,11 +34,16 @@ class TestGPU(unittest.TestCase): for rows in rowslist: eprint("Creating train data rows=%d cols=%d" % (rows,cols)) + tmp = time.time() np.random.seed(7) X = np.random.rand(rows, cols) y = np.random.rand(rows) + print("Time to Create Data: %r" % (time.time() - tmp)) + eprint("Starting DMatrix(X,y)") - ag_dtrain = xgb.DMatrix(X,y,nthread=0) + tmp = time.time() + ag_dtrain = xgb.DMatrix(X,y,nthread=40) + print("Time to DMatrix: %r" % (time.time() - tmp)) max_depth=6 max_bin=1024 @@ -48,6 +54,7 @@ class TestGPU(unittest.TestCase): 'nthread': 0, 'eta': 1, 'silent': 0, + 'debug_verbose': 5, 'objective': 'binary:logistic', 'eval_metric': 'auc'} ag_paramb = {'max_depth': max_depth, @@ -55,22 +62,25 @@ class TestGPU(unittest.TestCase): 'nthread': 0, 'eta': 1, 'silent': 0, + 'debug_verbose': 5, 'objective': 'binary:logistic', 'eval_metric': 'auc'} ag_param2 = {'max_depth': max_depth, - 'tree_method': 'gpu_hist', - 'nthread': 0, - 'eta': 1, - 'silent': 0, - 'n_gpus': 1, - 'objective': 'binary:logistic', - 'max_bin': max_bin, - 'eval_metric': 'auc'} + 'tree_method': 'gpu_hist', + 'nthread': 0, + 'eta': 1, + 'silent': 0, + 'debug_verbose': 5, + 'n_gpus': 1, + 'objective': 'binary:logistic', + 'max_bin': max_bin, + 'eval_metric': 'auc'} ag_param3 = {'max_depth': max_depth, 'tree_method': 'gpu_hist', 'nthread': 0, 'eta': 1, 'silent': 0, + 'debug_verbose': 5, 'n_gpus': -1, 'objective': 'binary:logistic', 'max_bin': max_bin, @@ -81,16 +91,23 @@ class TestGPU(unittest.TestCase): ag_res3 = {} num_rounds = 1 - - eprint("hist updater") - xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')], - evals_result=ag_resb) + tmp = time.time() + #eprint("hist updater") + #xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')], + # evals_result=ag_resb) + #print("Time to Train: %s seconds" % (str(time.time() - tmp))) + + tmp = time.time() eprint("gpu_hist updater 1 gpu") xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')], evals_result=ag_res2) + print("Time to Train: %s seconds" % (str(time.time() - tmp))) + + tmp = time.time() eprint("gpu_hist updater all gpus") xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')], evals_result=ag_res3) + print("Time to Train: %s seconds" % (str(time.time() - tmp)))
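
Note: the "[GPU Plug-in]" timing lines added in gpu_hist_builder.cu are only printed when the new debug_verbose training parameter is set above 0, which is how the updated test scripts exercise it. The short stand-alone sketch below (not part of the patch; the data shapes and most parameter values are illustrative assumptions) shows how the flag would be enabled from Python to surface that output:

# Illustrative sketch only: enable the new 'debug_verbose' flag so the
# "[GPU Plug-in]" timing output added by this patch is printed.
# Data sizes and parameter values here are arbitrary assumptions.
import numpy as np
import xgboost as xgb

np.random.seed(7)
X = np.random.rand(10000, 50)
y = (np.random.rand(10000) > 0.5).astype(int)
dtrain = xgb.DMatrix(X, y)

param = {'tree_method': 'gpu_hist',      # GPU histogram updater from this plugin
         'n_gpus': 1,                    # -1 would use all visible GPUs
         'max_depth': 6,
         'max_bin': 256,
         'silent': 0,
         'debug_verbose': 5,             # > 0 turns on the per-phase timing output
         'objective': 'binary:logistic',
         'eval_metric': 'auc'}

res = {}
xgb.train(param, dtrain, 10, [(dtrain, 'train')], evals_result=res)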