[GPU-Plugin] Throw from failed CUDA asserts and add compute capability error check. (#2565)
* [GPU-Plugin] Added compute capability error check and verbose timing
parent 75ea07b847
commit c1104f7d0a
@@ -63,15 +63,23 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,

 #define gpuErrchk(ans) \
   { gpuAssert((ans), __FILE__, __LINE__); }

 inline void gpuAssert(cudaError_t code, const char *file, int line,
                       bool abort = true) {
   if (code != cudaSuccess) {
     fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
             line);
-    if (abort) exit(code);
+    if (abort) {
+      std::stringstream ss;
+      ss << file << "(" << line << ")";
+      std::string file_and_line;
+      ss >> file_and_line;
+      throw thrust::system_error(code, thrust::cuda_category(), file_and_line);
+    }
   }
 }

 inline int n_visible_devices() {
   int n_visgpus = 0;
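The practical effect of the change above: a failing CUDA call wrapped in gpuErrchk now raises a catchable thrust::system_error instead of killing the process with exit(code). A minimal sketch of a caller, assuming the macro lives in a header named "device_helpers.cuh" (that include name is illustrative, not from the patch):

    // Illustrative use of the new throwing gpuErrchk; the header name is assumed.
    #include <thrust/system_error.h>
    #include <cuda_runtime.h>
    #include <iostream>
    #include "device_helpers.cuh"

    int main() {
      try {
        void *ptr = nullptr;
        // Deliberately request an absurd allocation to force a CUDA error.
        gpuErrchk(cudaMalloc(&ptr, ~size_t(0)));
      } catch (const thrust::system_error &e) {
        // Under the old exit(code) behaviour this handler could never run.
        std::cerr << "Recovered from CUDA error: " << e.what() << std::endl;
      }
      return 0;
    }

This is what lets a host application, for example the Python wrapper, turn a device failure into an error message rather than losing the whole process.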
@@ -6,6 +6,8 @@
 #include <thrust/sequence.h>
 #include <thrust/sort.h>
 #include <cub/cub.cuh>
+#include <string>
+#include <sstream>
 #include <algorithm>
 #include <functional>
 #include <future>
@@ -130,7 +132,7 @@ void GPUHistBuilder::Init(const TrainParam& param) {
   if (!param.silent) {
     size_t free_memory = dh::available_memory(device_idx);
     const int mb_size = 1048576;
-    LOG(CONSOLE) << "Device: [" << device_idx << "] "
+    LOG(CONSOLE) << "[GPU Plug-in] Device: [" << device_idx << "] "
                  << dh::device_name(device_idx) << " with "
                  << free_memory / mb_size << " MB available device memory.";
   }
@@ -139,12 +141,19 @@ void GPUHistBuilder::Init(const TrainParam& param) {
 void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
                               DMatrix& fmat,  // NOLINT
                               const RegTree& tree) {
+  dh::Timer time1;
   // set member num_rows and n_devices for rest of GPUHistBuilder members
   info = &fmat.info();
   num_rows = info->num_row;
   n_devices = dh::n_devices(param.n_gpus, num_rows);

   if (!initialised) {
+    // reset static timers used across iterations
+    cpu_init_time = 0;
+    gpu_init_time = 0;
+    cpu_time.reset();
+    gpu_time = 0;
+
     // set dList member
     dList.resize(n_devices);
     for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
@@ -176,7 +185,13 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
     dh::safe_cuda(cudaGetDeviceProperties(&prop, cudaDev));
     // printf("# Rank %2d uses device %2d [0x%02x] %s\n", rank, cudaDev,
     //        prop.pciBusID, prop.name);
     fflush(stdout);
     // cudaDriverGetVersion(&driverVersion);
     // cudaRuntimeGetVersion(&runtimeVersion);
+    std::ostringstream oss;
+    oss << "CUDA Capability Major/Minor version number: "
+        << prop.major << "." << prop.minor << " is insufficient. Need >=3.5.";
+    int failed = prop.major < 3 || (prop.major == 3 && prop.minor < 5);
+    CHECK(failed == 0) << oss.str();
   }

   // local find_split group of comms for each case of reduced number of GPUs
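The added CHECK fails fast on devices older than compute capability 3.5. The same test in standalone form, using only the CUDA runtime API; the helper name meets_min_capability is invented for this sketch:

    #include <cuda_runtime.h>
    #include <cstdio>

    // True if the device satisfies the >= 3.5 compute capability requirement
    // that the plug-in's find_split kernel launch configuration assumes.
    bool meets_min_capability(int device_idx) {
      cudaDeviceProp prop;
      if (cudaGetDeviceProperties(&prop, device_idx) != cudaSuccess) return false;
      return prop.major > 3 || (prop.major == 3 && prop.minor >= 5);
    }

    int main() {
      int n = 0;
      cudaGetDeviceCount(&n);
      for (int d = 0; d < n; ++d) {
        std::printf("device %d: %s\n", d,
                    meets_min_capability(d) ? "OK (>= 3.5)" : "too old (< 3.5)");
      }
      return 0;
    }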
@@ -199,9 +214,40 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
                  "block. Try setting 'tree_method' "
                  "parameter to 'exact'";
   is_dense = info->num_nonzero == info->num_col * info->num_row;
+  dh::Timer time0;
   hmat_.Init(&fmat, param.max_bin);
+  cpu_init_time += time0.elapsedSeconds();
+  if (param.debug_verbose) {  // Only done once for each training session
+    LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init "
+                 << time0.elapsedSeconds() << " sec";
+    fflush(stdout);
+  }
+  time0.reset();
+
   gmat_.cut = &hmat_;
+  cpu_init_time += time0.elapsedSeconds();
+  if (param.debug_verbose) {  // Only done once for each training session
+    LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.cut "
+                 << time0.elapsedSeconds() << " sec";
+    fflush(stdout);
+  }
+  time0.reset();
+
   gmat_.Init(&fmat);
+  cpu_init_time += time0.elapsedSeconds();
+  if (param.debug_verbose) {  // Only done once for each training session
+    LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.Init() "
+                 << time0.elapsedSeconds() << " sec";
+    fflush(stdout);
+  }
+  time0.reset();
+
+  if (param.debug_verbose) {  // Only done once for each training session
+    LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init, gmat_.cut, gmat_.Init "
+                 << cpu_init_time << " sec";
+    fflush(stdout);
+  }
+
   int n_bins = hmat_.row_ptr.back();
   int n_features = hmat_.row_ptr.size() - 1;
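The measure/log/reset pattern above accumulates each CPU phase into cpu_init_time. A self-contained sketch of that pattern with a std::chrono based timer; this Timer is a stand-in whose interface is inferred from the elapsedSeconds()/reset() call sites of dh::Timer, not the actual implementation:

    #include <chrono>
    #include <cstdio>

    // Stand-in for dh::Timer: starts on construction, reset() restarts it.
    class Timer {
      std::chrono::high_resolution_clock::time_point start_;
     public:
      Timer() { reset(); }
      void reset() { start_ = std::chrono::high_resolution_clock::now(); }
      double elapsedSeconds() const {
        return std::chrono::duration<double>(
            std::chrono::high_resolution_clock::now() - start_).count();
      }
    };

    int main() {
      double cpu_init_time = 0;
      Timer time0;
      // ... phase 1 work (e.g. building the histogram cuts) goes here ...
      cpu_init_time += time0.elapsedSeconds();
      std::printf("phase 1: %f sec\n", time0.elapsedSeconds());
      time0.reset();  // so the next phase is measured from zero
      // ... phase 2 work goes here ...
      cpu_init_time += time0.elapsedSeconds();
      std::printf("total so far: %f sec\n", cpu_init_time);
      return 0;
    }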
@@ -324,10 +370,8 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,

   if (!param.silent) {
     const int mb_size = 1048576;
-    LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << " MB";
+    LOG(CONSOLE) << "[GPU Plug-in] Allocated " << ba.size() / mb_size << " MB";
   }

   initialised = true;
 }

 // copy or init to do every iteration
@@ -355,7 +399,20 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,

   dh::synchronize_n_devices(n_devices, dList);

+  if (!initialised) {
+    gpu_init_time = time1.elapsedSeconds() - cpu_init_time;
+    gpu_time = -cpu_init_time;
+    if (param.debug_verbose) {  // Only done once for each training session
+      LOG(CONSOLE) << "[GPU Plug-in] Time for GPU operations during First Call to InitData() "
+                   << gpu_init_time << " sec";
+      fflush(stdout);
+    }
+  }
+
   p_last_fmat_ = &fmat;

+  initialised = true;
 }

 void GPUHistBuilder::BuildHist(int depth) {
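The bookkeeping here is worth unpacking: time1 spans the whole first InitData() call, so the GPU share of initialization is recovered by subtracting the separately accumulated CPU time, and gpu_time starts at -cpu_init_time so that later whole-call Update() timings net out the CPU-side setup. A worked example of the arithmetic, with all numbers invented for illustration:

    #include <cstdio>

    int main() {
      // Suppose the first Update() took 12 s: a 10 s InitData() (7 s of CPU
      // phases plus 3 s of GPU setup) and 2 s of tree building; a second
      // Update() then took 2 s.
      double cpu_init_time = 7.0;
      double gpu_init_time = 10.0 - cpu_init_time;  // 3 s: InitData minus CPU part
      double gpu_time = -cpu_init_time;             // pre-subtract CPU init

      gpu_time += 12.0;  // first Update(), includes InitData()
      gpu_time += 2.0;   // second Update()

      // gpu_time is now 7 s; removing the one-off 3 s of GPU init leaves the
      // 4 s of per-iteration GPU work, which is what the log line reports.
      std::printf("gpu excl. init: %.1f sec\n", gpu_time - gpu_init_time);
      return 0;
    }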
@@ -623,7 +680,7 @@ __global__ void find_split_kernel(
 #define MIN_BLOCK_THREADS 32
 #define CHUNK_BLOCK_THREADS 32
 // MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due
-// to CUDA compatibility 35 and above requirement
+// to CUDA capability 35 and above requirement
 // for Maximum number of threads per block
 #define MAX_BLOCK_THREADS 1024
@@ -1082,6 +1139,8 @@ bool GPUHistBuilder::UpdatePredictionCache(

 void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
                             DMatrix* p_fmat, RegTree* p_tree) {
+  dh::Timer time0;
+
   this->InitData(gpair, *p_fmat, *p_tree);
   this->InitFirstNode(gpair);
   this->ColSampleTree();
@@ -1097,6 +1156,24 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
   int master_device = dList[0];
   dh::safe_cuda(cudaSetDevice(master_device));
   dense2sparse_tree(p_tree, nodes[0].tbegin(), nodes[0].tend(), param);

+  gpu_time += time0.elapsedSeconds();
+
+  if (param.debug_verbose) {
+    LOG(CONSOLE) << "[GPU Plug-in] Cumulative GPU Time excluding initial time "
+                 << (gpu_time - gpu_init_time)
+                 << " sec";
+    fflush(stdout);
+  }
+
+  if (param.debug_verbose) {
+    LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time "
+                 << cpu_time.elapsedSeconds() << " sec";
+    LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time excluding initial time "
+                 << (cpu_time.elapsedSeconds() - cpu_init_time - gpu_time)
+                 << " sec";
+    fflush(stdout);
+  }
 }
 }  // namespace tree
 }  // namespace xgboost
@@ -122,6 +122,13 @@ class GPUHistBuilder {
   std::vector<cudaStream_t *> streams;
   std::vector<ncclComm_t> comms;
   std::vector<std::vector<ncclComm_t>> find_split_comms;
+
+  double cpu_init_time;
+  double gpu_init_time;
+  dh::Timer cpu_time;
+  double gpu_time;
+
 };
 }  // namespace tree
 }  // namespace xgboost
@@ -32,6 +32,7 @@ class TestGPU(unittest.TestCase):
     'nthread': 0,
     'eta': 1,
     'silent': 1,
+    'debug_verbose': 0,
     'objective': 'binary:logistic',
     'eval_metric': 'auc'}
 ag_param2 = {'max_depth': 2,
@@ -39,6 +40,7 @@ class TestGPU(unittest.TestCase):
     'nthread': 0,
     'eta': 1,
     'silent': 1,
+    'debug_verbose': 0,
     'objective': 'binary:logistic',
     'eval_metric': 'auc'}
 ag_res = {}
@@ -63,6 +65,7 @@ class TestGPU(unittest.TestCase):
     'nthread': 0,
     'tree_method': 'gpu_exact',
     'max_depth': 3,
+    'debug_verbose': 0,
     'eval_metric': 'auc'}
 res = {}
 xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
@@ -80,6 +83,7 @@ class TestGPU(unittest.TestCase):
     'nthread': 0,
     'tree_method': 'gpu_exact',
     'max_depth': 2,
+    'debug_verbose': 0,
     'eval_metric': 'auc'}
 res = {}
 xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
@@ -134,6 +138,7 @@ class TestGPU(unittest.TestCase):
     'nthread': 0,
     'eta': 1,
     'silent': 1,
+    'debug_verbose': 0,
     'objective': 'binary:logistic',
     'eval_metric': 'auc'}
 ag_param2 = {'max_depth': max_depth,
@@ -141,6 +146,7 @@ class TestGPU(unittest.TestCase):
     'tree_method': 'gpu_hist',
     'eta': 1,
     'silent': 1,
+    'debug_verbose': 0,
     'n_gpus': 1,
     'objective': 'binary:logistic',
     'max_bin': max_bin,
@@ -150,6 +156,7 @@ class TestGPU(unittest.TestCase):
     'tree_method': 'gpu_hist',
     'eta': 1,
     'silent': 1,
+    'debug_verbose': 0,
     'n_gpus': n_gpus,
     'objective': 'binary:logistic',
     'max_bin': max_bin,
@@ -187,6 +194,7 @@ class TestGPU(unittest.TestCase):
     'max_depth': max_depth,
     'n_gpus': 1,
     'max_bin': max_bin,
+    'debug_verbose': 0,
     'eval_metric': 'auc'}
 res = {}
 #eprint("digits: grow_gpu_hist updater 1 gpu");
@@ -200,6 +208,7 @@ class TestGPU(unittest.TestCase):
     'max_depth': max_depth,
     'n_gpus': n_gpus,
     'max_bin': max_bin,
+    'debug_verbose': 0,
     'eval_metric': 'auc'}
 res2 = {}
 #eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
@@ -223,6 +232,7 @@ class TestGPU(unittest.TestCase):
     'max_depth': max_depth,
     'n_gpus': n_gpus,
     'max_bin': max_bin,
+    'debug_verbose': 0,
     'eval_metric': 'auc'}
 res = {}
 xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
@@ -262,6 +272,7 @@ class TestGPU(unittest.TestCase):
     'tree_method': 'gpu_hist',
     'max_depth': max_depth,
     'n_gpus': n_gpus,
+    'debug_verbose': 0,
     'eval_metric': 'auc',
     'max_bin': max_bin}
 res = {}
@@ -280,6 +291,7 @@ class TestGPU(unittest.TestCase):
     'colsample_bytree': 0.5,
     'colsample_bylevel': 0.5,
     'subsample': 0.5,
+    'debug_verbose': 0,
     'max_bin': max_bin}
 res = {}
 xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
@@ -293,6 +305,7 @@ class TestGPU(unittest.TestCase):
     'tree_method': 'gpu_hist',
     'max_depth': 2,
     'n_gpus': n_gpus,
+    'debug_verbose': 0,
     'eval_metric': 'auc',
     'max_bin': 2}
 res = {}
@@ -1,6 +1,7 @@
 from __future__ import print_function
 #pylint: skip-file
 import sys
+import time
 sys.path.append("../../tests/python")
 import xgboost as xgb
 import testing as tm
@@ -33,11 +34,16 @@ class TestGPU(unittest.TestCase):
 for rows in rowslist:

     eprint("Creating train data rows=%d cols=%d" % (rows,cols))
+    tmp = time.time()
     np.random.seed(7)
     X = np.random.rand(rows, cols)
     y = np.random.rand(rows)
+    print("Time to Create Data: %r" % (time.time() - tmp))

     eprint("Starting DMatrix(X,y)")
-    ag_dtrain = xgb.DMatrix(X,y,nthread=0)
+    tmp = time.time()
+    ag_dtrain = xgb.DMatrix(X,y,nthread=40)
+    print("Time to DMatrix: %r" % (time.time() - tmp))

     max_depth=6
     max_bin=1024
@@ -48,6 +54,7 @@ class TestGPU(unittest.TestCase):
     'nthread': 0,
     'eta': 1,
     'silent': 0,
+    'debug_verbose': 5,
     'objective': 'binary:logistic',
     'eval_metric': 'auc'}
 ag_paramb = {'max_depth': max_depth,
@@ -55,6 +62,7 @@ class TestGPU(unittest.TestCase):
     'nthread': 0,
     'eta': 1,
     'silent': 0,
+    'debug_verbose': 5,
     'objective': 'binary:logistic',
     'eval_metric': 'auc'}
 ag_param2 = {'max_depth': max_depth,
@@ -62,6 +70,7 @@ class TestGPU(unittest.TestCase):
     'nthread': 0,
     'eta': 1,
     'silent': 0,
+    'debug_verbose': 5,
     'n_gpus': 1,
     'objective': 'binary:logistic',
     'max_bin': max_bin,
@@ -71,6 +80,7 @@ class TestGPU(unittest.TestCase):
     'nthread': 0,
     'eta': 1,
     'silent': 0,
+    'debug_verbose': 5,
     'n_gpus': -1,
     'objective': 'binary:logistic',
     'max_bin': max_bin,
@@ -81,16 +91,23 @@ class TestGPU(unittest.TestCase):
 ag_res3 = {}

 num_rounds = 1
 tmp = time.time()
 #eprint("hist updater")
 #xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
 #          evals_result=ag_resb)
 #print("Time to Train: %s seconds" % (str(time.time() - tmp)))

 eprint("hist updater")
 xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
           evals_result=ag_resb)
 tmp = time.time()
 eprint("gpu_hist updater 1 gpu")
 xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
           evals_result=ag_res2)
 print("Time to Train: %s seconds" % (str(time.time() - tmp)))

 tmp = time.time()
 eprint("gpu_hist updater all gpus")
 xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
           evals_result=ag_res3)
 print("Time to Train: %s seconds" % (str(time.time() - tmp)))