[GPU-Plugin] Add throw of asserts and added compute compatibility error check. (#2565)

* [GPU-Plugin] Added compute compatibility error check, added verbose timing
This commit is contained in:
PSEUDOTENSOR / Jonathan McKinney 2017-08-09 21:07:07 -07:00 committed by Rory Mitchell
parent 75ea07b847
commit c1104f7d0a
5 changed files with 142 additions and 20 deletions

View File

@ -63,15 +63,23 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
#define gpuErrchk(ans) \ #define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); } { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, inline void gpuAssert(cudaError_t code, const char *file, int line,
bool abort = true) { bool abort = true) {
if (code != cudaSuccess) { if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line); line);
if (abort) exit(code); if (abort){
std::stringstream ss;
ss << file << "(" << line << ")";
std::string file_and_line;
ss >> file_and_line;
throw thrust::system_error(code, thrust::cuda_category(), file_and_line);
}
} }
} }
inline int n_visible_devices() { inline int n_visible_devices() {
int n_visgpus = 0; int n_visgpus = 0;

View File

@ -6,6 +6,8 @@
#include <thrust/sequence.h> #include <thrust/sequence.h>
#include <thrust/sort.h> #include <thrust/sort.h>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <string>
#include <sstream>
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <future> #include <future>
@ -130,7 +132,7 @@ void GPUHistBuilder::Init(const TrainParam& param) {
if (!param.silent) { if (!param.silent) {
size_t free_memory = dh::available_memory(device_idx); size_t free_memory = dh::available_memory(device_idx);
const int mb_size = 1048576; const int mb_size = 1048576;
LOG(CONSOLE) << "Device: [" << device_idx << "] " LOG(CONSOLE) << "[GPU Plug-in] Device: [" << device_idx << "] "
<< dh::device_name(device_idx) << " with " << dh::device_name(device_idx) << " with "
<< free_memory / mb_size << " MB available device memory."; << free_memory / mb_size << " MB available device memory.";
} }
@ -139,12 +141,19 @@ void GPUHistBuilder::Init(const TrainParam& param) {
void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair, void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
DMatrix& fmat, // NOLINT DMatrix& fmat, // NOLINT
const RegTree& tree) { const RegTree& tree) {
dh::Timer time1;
// set member num_rows and n_devices for rest of GPUHistBuilder members // set member num_rows and n_devices for rest of GPUHistBuilder members
info = &fmat.info(); info = &fmat.info();
num_rows = info->num_row; num_rows = info->num_row;
n_devices = dh::n_devices(param.n_gpus, num_rows); n_devices = dh::n_devices(param.n_gpus, num_rows);
if (!initialised) { if (!initialised) {
// reset static timers used across iterations
cpu_init_time = 0;
gpu_init_time = 0;
cpu_time.reset();
gpu_time = 0;
// set dList member // set dList member
dList.resize(n_devices); dList.resize(n_devices);
for (int d_idx = 0; d_idx < n_devices; ++d_idx) { for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
@ -176,7 +185,13 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
dh::safe_cuda(cudaGetDeviceProperties(&prop, cudaDev)); dh::safe_cuda(cudaGetDeviceProperties(&prop, cudaDev));
// printf("# Rank %2d uses device %2d [0x%02x] %s\n", rank, cudaDev, // printf("# Rank %2d uses device %2d [0x%02x] %s\n", rank, cudaDev,
// prop.pciBusID, prop.name); // prop.pciBusID, prop.name);
fflush(stdout); // cudaDriverGetVersion(&driverVersion);
// cudaRuntimeGetVersion(&runtimeVersion);
std::ostringstream oss;
oss << "CUDA Capability Major/Minor version number: "
<< prop.major << "." << prop.minor << " is insufficient. Need >=3.5.";
int failed = prop.major < 3 || prop.major == 3 && prop.minor < 5;
CHECK(failed == 0) << oss.str();
} }
// local find_split group of comms for each case of reduced number of GPUs // local find_split group of comms for each case of reduced number of GPUs
@ -199,9 +214,40 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
"block. Try setting 'tree_method' " "block. Try setting 'tree_method' "
"parameter to 'exact'"; "parameter to 'exact'";
is_dense = info->num_nonzero == info->num_col * info->num_row; is_dense = info->num_nonzero == info->num_col * info->num_row;
dh::Timer time0;
hmat_.Init(&fmat, param.max_bin); hmat_.Init(&fmat, param.max_bin);
cpu_init_time += time0.elapsedSeconds();
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init "
<< time0.elapsedSeconds() << " sec";
fflush(stdout);
}
time0.reset();
gmat_.cut = &hmat_; gmat_.cut = &hmat_;
cpu_init_time += time0.elapsedSeconds();
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.cut "
<< time0.elapsedSeconds() << " sec";
fflush(stdout);
}
time0.reset();
gmat_.Init(&fmat); gmat_.Init(&fmat);
cpu_init_time += time0.elapsedSeconds();
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.Init() "
<< time0.elapsedSeconds() << " sec";
fflush(stdout);
}
time0.reset();
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init, gmat_.cut, gmat_.Init "
<< cpu_init_time << " sec";
fflush(stdout);
}
int n_bins = hmat_.row_ptr.back(); int n_bins = hmat_.row_ptr.back();
int n_features = hmat_.row_ptr.size() - 1; int n_features = hmat_.row_ptr.size() - 1;
@ -324,10 +370,8 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
if (!param.silent) { if (!param.silent) {
const int mb_size = 1048576; const int mb_size = 1048576;
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << " MB"; LOG(CONSOLE) << "[GPU Plug-in] Allocated " << ba.size() / mb_size << " MB";
} }
initialised = true;
} }
// copy or init to do every iteration // copy or init to do every iteration
@ -355,7 +399,20 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
dh::synchronize_n_devices(n_devices, dList); dh::synchronize_n_devices(n_devices, dList);
if (!initialised) {
gpu_init_time = time1.elapsedSeconds() - cpu_init_time;
gpu_time = -cpu_init_time;
if (param.debug_verbose) { // Only done once for each training session
LOG(CONSOLE) << "[GPU Plug-in] Time for GPU operations during First Call to InitData() "
<< gpu_init_time << " sec";
fflush(stdout);
}
}
p_last_fmat_ = &fmat; p_last_fmat_ = &fmat;
initialised = true;
} }
void GPUHistBuilder::BuildHist(int depth) { void GPUHistBuilder::BuildHist(int depth) {
@ -623,7 +680,7 @@ __global__ void find_split_kernel(
#define MIN_BLOCK_THREADS 32 #define MIN_BLOCK_THREADS 32
#define CHUNK_BLOCK_THREADS 32 #define CHUNK_BLOCK_THREADS 32
// MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due // MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due
// to CUDA compatibility 35 and above requirement // to CUDA capability 35 and above requirement
// for Maximum number of threads per block // for Maximum number of threads per block
#define MAX_BLOCK_THREADS 1024 #define MAX_BLOCK_THREADS 1024
@ -1082,6 +1139,8 @@ bool GPUHistBuilder::UpdatePredictionCache(
void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair, void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
DMatrix* p_fmat, RegTree* p_tree) { DMatrix* p_fmat, RegTree* p_tree) {
dh::Timer time0;
this->InitData(gpair, *p_fmat, *p_tree); this->InitData(gpair, *p_fmat, *p_tree);
this->InitFirstNode(gpair); this->InitFirstNode(gpair);
this->ColSampleTree(); this->ColSampleTree();
@ -1097,6 +1156,24 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
int master_device = dList[0]; int master_device = dList[0];
dh::safe_cuda(cudaSetDevice(master_device)); dh::safe_cuda(cudaSetDevice(master_device));
dense2sparse_tree(p_tree, nodes[0].tbegin(), nodes[0].tend(), param); dense2sparse_tree(p_tree, nodes[0].tbegin(), nodes[0].tend(), param);
gpu_time += time0.elapsedSeconds();
if (param.debug_verbose) {
LOG(CONSOLE) << "[GPU Plug-in] Cumulative GPU Time excluding initial time "
<< (gpu_time - gpu_init_time)
<< " sec";
fflush(stdout);
}
if (param.debug_verbose) {
LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time "
<< cpu_time.elapsedSeconds() << " sec";
LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time excluding initial time "
<< (cpu_time.elapsedSeconds() - cpu_init_time - gpu_time)
<< " sec";
fflush(stdout);
}
} }
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -122,6 +122,13 @@ class GPUHistBuilder {
std::vector<cudaStream_t *> streams; std::vector<cudaStream_t *> streams;
std::vector<ncclComm_t> comms; std::vector<ncclComm_t> comms;
std::vector<std::vector<ncclComm_t>> find_split_comms; std::vector<std::vector<ncclComm_t>> find_split_comms;
double cpu_init_time;
double gpu_init_time;
dh::Timer cpu_time;
double gpu_time;
}; };
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -32,6 +32,7 @@ class TestGPU(unittest.TestCase):
'nthread': 0, 'nthread': 0,
'eta': 1, 'eta': 1,
'silent': 1, 'silent': 1,
'debug_verbose': 0,
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'eval_metric': 'auc'} 'eval_metric': 'auc'}
ag_param2 = {'max_depth': 2, ag_param2 = {'max_depth': 2,
@ -39,6 +40,7 @@ class TestGPU(unittest.TestCase):
'nthread': 0, 'nthread': 0,
'eta': 1, 'eta': 1,
'silent': 1, 'silent': 1,
'debug_verbose': 0,
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'eval_metric': 'auc'} 'eval_metric': 'auc'}
ag_res = {} ag_res = {}
@ -63,6 +65,7 @@ class TestGPU(unittest.TestCase):
'nthread': 0, 'nthread': 0,
'tree_method': 'gpu_exact', 'tree_method': 'gpu_exact',
'max_depth': 3, 'max_depth': 3,
'debug_verbose': 0,
'eval_metric': 'auc'} 'eval_metric': 'auc'}
res = {} res = {}
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')], xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
@ -80,6 +83,7 @@ class TestGPU(unittest.TestCase):
'nthread': 0, 'nthread': 0,
'tree_method': 'gpu_exact', 'tree_method': 'gpu_exact',
'max_depth': 2, 'max_depth': 2,
'debug_verbose': 0,
'eval_metric': 'auc'} 'eval_metric': 'auc'}
res = {} res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res) xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
@ -134,6 +138,7 @@ class TestGPU(unittest.TestCase):
'nthread': 0, 'nthread': 0,
'eta': 1, 'eta': 1,
'silent': 1, 'silent': 1,
'debug_verbose': 0,
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'eval_metric': 'auc'} 'eval_metric': 'auc'}
ag_param2 = {'max_depth': max_depth, ag_param2 = {'max_depth': max_depth,
@ -141,6 +146,7 @@ class TestGPU(unittest.TestCase):
'tree_method': 'gpu_hist', 'tree_method': 'gpu_hist',
'eta': 1, 'eta': 1,
'silent': 1, 'silent': 1,
'debug_verbose': 0,
'n_gpus': 1, 'n_gpus': 1,
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'max_bin': max_bin, 'max_bin': max_bin,
@ -150,6 +156,7 @@ class TestGPU(unittest.TestCase):
'tree_method': 'gpu_hist', 'tree_method': 'gpu_hist',
'eta': 1, 'eta': 1,
'silent': 1, 'silent': 1,
'debug_verbose': 0,
'n_gpus': n_gpus, 'n_gpus': n_gpus,
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'max_bin': max_bin, 'max_bin': max_bin,
@ -187,6 +194,7 @@ class TestGPU(unittest.TestCase):
'max_depth': max_depth, 'max_depth': max_depth,
'n_gpus': 1, 'n_gpus': 1,
'max_bin': max_bin, 'max_bin': max_bin,
'debug_verbose': 0,
'eval_metric': 'auc'} 'eval_metric': 'auc'}
res = {} res = {}
#eprint("digits: grow_gpu_hist updater 1 gpu"); #eprint("digits: grow_gpu_hist updater 1 gpu");
@ -200,6 +208,7 @@ class TestGPU(unittest.TestCase):
'max_depth': max_depth, 'max_depth': max_depth,
'n_gpus': n_gpus, 'n_gpus': n_gpus,
'max_bin': max_bin, 'max_bin': max_bin,
'debug_verbose': 0,
'eval_metric': 'auc'} 'eval_metric': 'auc'}
res2 = {} res2 = {}
#eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus)); #eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
@ -223,6 +232,7 @@ class TestGPU(unittest.TestCase):
'max_depth': max_depth, 'max_depth': max_depth,
'n_gpus': n_gpus, 'n_gpus': n_gpus,
'max_bin': max_bin, 'max_bin': max_bin,
'debug_verbose': 0,
'eval_metric': 'auc'} 'eval_metric': 'auc'}
res = {} res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res) xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
@ -262,6 +272,7 @@ class TestGPU(unittest.TestCase):
'tree_method': 'gpu_hist', 'tree_method': 'gpu_hist',
'max_depth': max_depth, 'max_depth': max_depth,
'n_gpus': n_gpus, 'n_gpus': n_gpus,
'debug_verbose': 0,
'eval_metric': 'auc', 'eval_metric': 'auc',
'max_bin': max_bin} 'max_bin': max_bin}
res = {} res = {}
@ -280,6 +291,7 @@ class TestGPU(unittest.TestCase):
'colsample_bytree': 0.5, 'colsample_bytree': 0.5,
'colsample_bylevel': 0.5, 'colsample_bylevel': 0.5,
'subsample': 0.5, 'subsample': 0.5,
'debug_verbose': 0,
'max_bin': max_bin} 'max_bin': max_bin}
res = {} res = {}
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res) xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
@ -293,6 +305,7 @@ class TestGPU(unittest.TestCase):
'tree_method': 'gpu_hist', 'tree_method': 'gpu_hist',
'max_depth': 2, 'max_depth': 2,
'n_gpus': n_gpus, 'n_gpus': n_gpus,
'debug_verbose': 0,
'eval_metric': 'auc', 'eval_metric': 'auc',
'max_bin': 2} 'max_bin': 2}
res = {} res = {}

View File

@ -1,6 +1,7 @@
from __future__ import print_function from __future__ import print_function
#pylint: skip-file #pylint: skip-file
import sys import sys
import time
sys.path.append("../../tests/python") sys.path.append("../../tests/python")
import xgboost as xgb import xgboost as xgb
import testing as tm import testing as tm
@ -33,11 +34,16 @@ class TestGPU(unittest.TestCase):
for rows in rowslist: for rows in rowslist:
eprint("Creating train data rows=%d cols=%d" % (rows,cols)) eprint("Creating train data rows=%d cols=%d" % (rows,cols))
tmp = time.time()
np.random.seed(7) np.random.seed(7)
X = np.random.rand(rows, cols) X = np.random.rand(rows, cols)
y = np.random.rand(rows) y = np.random.rand(rows)
print("Time to Create Data: %r" % (time.time() - tmp))
eprint("Starting DMatrix(X,y)") eprint("Starting DMatrix(X,y)")
ag_dtrain = xgb.DMatrix(X,y,nthread=0) tmp = time.time()
ag_dtrain = xgb.DMatrix(X,y,nthread=40)
print("Time to DMatrix: %r" % (time.time() - tmp))
max_depth=6 max_depth=6
max_bin=1024 max_bin=1024
@ -48,6 +54,7 @@ class TestGPU(unittest.TestCase):
'nthread': 0, 'nthread': 0,
'eta': 1, 'eta': 1,
'silent': 0, 'silent': 0,
'debug_verbose': 5,
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'eval_metric': 'auc'} 'eval_metric': 'auc'}
ag_paramb = {'max_depth': max_depth, ag_paramb = {'max_depth': max_depth,
@ -55,22 +62,25 @@ class TestGPU(unittest.TestCase):
'nthread': 0, 'nthread': 0,
'eta': 1, 'eta': 1,
'silent': 0, 'silent': 0,
'debug_verbose': 5,
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'eval_metric': 'auc'} 'eval_metric': 'auc'}
ag_param2 = {'max_depth': max_depth, ag_param2 = {'max_depth': max_depth,
'tree_method': 'gpu_hist', 'tree_method': 'gpu_hist',
'nthread': 0, 'nthread': 0,
'eta': 1, 'eta': 1,
'silent': 0, 'silent': 0,
'n_gpus': 1, 'debug_verbose': 5,
'objective': 'binary:logistic', 'n_gpus': 1,
'max_bin': max_bin, 'objective': 'binary:logistic',
'eval_metric': 'auc'} 'max_bin': max_bin,
'eval_metric': 'auc'}
ag_param3 = {'max_depth': max_depth, ag_param3 = {'max_depth': max_depth,
'tree_method': 'gpu_hist', 'tree_method': 'gpu_hist',
'nthread': 0, 'nthread': 0,
'eta': 1, 'eta': 1,
'silent': 0, 'silent': 0,
'debug_verbose': 5,
'n_gpus': -1, 'n_gpus': -1,
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'max_bin': max_bin, 'max_bin': max_bin,
@ -81,16 +91,23 @@ class TestGPU(unittest.TestCase):
ag_res3 = {} ag_res3 = {}
num_rounds = 1 num_rounds = 1
tmp = time.time()
eprint("hist updater") #eprint("hist updater")
xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')], #xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
evals_result=ag_resb) # evals_result=ag_resb)
#print("Time to Train: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
eprint("gpu_hist updater 1 gpu") eprint("gpu_hist updater 1 gpu")
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')], xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
evals_result=ag_res2) evals_result=ag_res2)
print("Time to Train: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
eprint("gpu_hist updater all gpus") eprint("gpu_hist updater all gpus")
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')], xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
evals_result=ag_res3) evals_result=ag_res3)
print("Time to Train: %s seconds" % (str(time.time() - tmp)))