[GPU-Plugin] Throw on failed asserts and add compute capability error check. (#2565)
* [GPU-Plugin] Added compute capability error check, added verbose timing
This commit is contained in:
parent
75ea07b847
commit
c1104f7d0a
@ -63,14 +63,22 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
|
|||||||
|
|
||||||
#define gpuErrchk(ans) \
|
#define gpuErrchk(ans) \
|
||||||
{ gpuAssert((ans), __FILE__, __LINE__); }
|
{ gpuAssert((ans), __FILE__, __LINE__); }
|
||||||
|
|
||||||
inline void gpuAssert(cudaError_t code, const char *file, int line,
|
inline void gpuAssert(cudaError_t code, const char *file, int line,
|
||||||
bool abort = true) {
|
bool abort = true) {
|
||||||
if (code != cudaSuccess) {
|
if (code != cudaSuccess) {
|
||||||
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
|
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
|
||||||
line);
|
line);
|
||||||
if (abort) exit(code);
|
if (abort){
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << file << "(" << line << ")";
|
||||||
|
std::string file_and_line;
|
||||||
|
ss >> file_and_line;
|
||||||
|
throw thrust::system_error(code, thrust::cuda_category(), file_and_line);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
inline int n_visible_devices() {
|
inline int n_visible_devices() {
|
||||||
int n_visgpus = 0;
|
int n_visgpus = 0;
|
||||||
|
|||||||
@ -6,6 +6,8 @@
|
|||||||
#include <thrust/sequence.h>
|
#include <thrust/sequence.h>
|
||||||
#include <thrust/sort.h>
|
#include <thrust/sort.h>
|
||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <future>
|
#include <future>
|
||||||
@ -130,7 +132,7 @@ void GPUHistBuilder::Init(const TrainParam& param) {
|
|||||||
if (!param.silent) {
|
if (!param.silent) {
|
||||||
size_t free_memory = dh::available_memory(device_idx);
|
size_t free_memory = dh::available_memory(device_idx);
|
||||||
const int mb_size = 1048576;
|
const int mb_size = 1048576;
|
||||||
LOG(CONSOLE) << "Device: [" << device_idx << "] "
|
LOG(CONSOLE) << "[GPU Plug-in] Device: [" << device_idx << "] "
|
||||||
<< dh::device_name(device_idx) << " with "
|
<< dh::device_name(device_idx) << " with "
|
||||||
<< free_memory / mb_size << " MB available device memory.";
|
<< free_memory / mb_size << " MB available device memory.";
|
||||||
}
|
}
|
||||||
@ -139,12 +141,19 @@ void GPUHistBuilder::Init(const TrainParam& param) {
|
|||||||
void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||||
DMatrix& fmat, // NOLINT
|
DMatrix& fmat, // NOLINT
|
||||||
const RegTree& tree) {
|
const RegTree& tree) {
|
||||||
|
dh::Timer time1;
|
||||||
// set member num_rows and n_devices for rest of GPUHistBuilder members
|
// set member num_rows and n_devices for rest of GPUHistBuilder members
|
||||||
info = &fmat.info();
|
info = &fmat.info();
|
||||||
num_rows = info->num_row;
|
num_rows = info->num_row;
|
||||||
n_devices = dh::n_devices(param.n_gpus, num_rows);
|
n_devices = dh::n_devices(param.n_gpus, num_rows);
|
||||||
|
|
||||||
if (!initialised) {
|
if (!initialised) {
|
||||||
|
// reset the timing members that accumulate across iterations
|
||||||
|
cpu_init_time = 0;
|
||||||
|
gpu_init_time = 0;
|
||||||
|
cpu_time.reset();
|
||||||
|
gpu_time = 0;
|
||||||
|
|
||||||
// set dList member
|
// set dList member
|
||||||
dList.resize(n_devices);
|
dList.resize(n_devices);
|
||||||
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
|
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
|
||||||
@ -176,7 +185,13 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
|||||||
dh::safe_cuda(cudaGetDeviceProperties(&prop, cudaDev));
|
dh::safe_cuda(cudaGetDeviceProperties(&prop, cudaDev));
|
||||||
// printf("# Rank %2d uses device %2d [0x%02x] %s\n", rank, cudaDev,
|
// printf("# Rank %2d uses device %2d [0x%02x] %s\n", rank, cudaDev,
|
||||||
// prop.pciBusID, prop.name);
|
// prop.pciBusID, prop.name);
|
||||||
fflush(stdout);
|
// cudaDriverGetVersion(&driverVersion);
|
||||||
|
// cudaRuntimeGetVersion(&runtimeVersion);
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "CUDA Capability Major/Minor version number: "
|
||||||
|
<< prop.major << "." << prop.minor << " is insufficient. Need >=3.5.";
|
||||||
|
int failed = prop.major < 3 || prop.major == 3 && prop.minor < 5;
|
||||||
|
CHECK(failed == 0) << oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
// local find_split group of comms for each case of reduced number of GPUs
|
// local find_split group of comms for each case of reduced number of GPUs
|
||||||
@ -199,9 +214,40 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
|||||||
"block. Try setting 'tree_method' "
|
"block. Try setting 'tree_method' "
|
||||||
"parameter to 'exact'";
|
"parameter to 'exact'";
|
||||||
is_dense = info->num_nonzero == info->num_col * info->num_row;
|
is_dense = info->num_nonzero == info->num_col * info->num_row;
|
||||||
|
dh::Timer time0;
|
||||||
hmat_.Init(&fmat, param.max_bin);
|
hmat_.Init(&fmat, param.max_bin);
|
||||||
|
cpu_init_time += time0.elapsedSeconds();
|
||||||
|
if (param.debug_verbose) { // Only done once for each training session
|
||||||
|
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init "
|
||||||
|
<< time0.elapsedSeconds() << " sec";
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
time0.reset();
|
||||||
|
|
||||||
gmat_.cut = &hmat_;
|
gmat_.cut = &hmat_;
|
||||||
|
cpu_init_time += time0.elapsedSeconds();
|
||||||
|
if (param.debug_verbose) { // Only done once for each training session
|
||||||
|
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.cut "
|
||||||
|
<< time0.elapsedSeconds() << " sec";
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
time0.reset();
|
||||||
|
|
||||||
gmat_.Init(&fmat);
|
gmat_.Init(&fmat);
|
||||||
|
cpu_init_time += time0.elapsedSeconds();
|
||||||
|
if (param.debug_verbose) { // Only done once for each training session
|
||||||
|
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for gmat_.Init() "
|
||||||
|
<< time0.elapsedSeconds() << " sec";
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
time0.reset();
|
||||||
|
|
||||||
|
if (param.debug_verbose) { // Only done once for each training session
|
||||||
|
LOG(CONSOLE) << "[GPU Plug-in] CPU Time for hmat_.Init, gmat_.cut, gmat_.Init "
|
||||||
|
<< cpu_init_time << " sec";
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
int n_bins = hmat_.row_ptr.back();
|
int n_bins = hmat_.row_ptr.back();
|
||||||
int n_features = hmat_.row_ptr.size() - 1;
|
int n_features = hmat_.row_ptr.size() - 1;
|
||||||
|
|
||||||
@ -324,10 +370,8 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
|||||||
|
|
||||||
if (!param.silent) {
|
if (!param.silent) {
|
||||||
const int mb_size = 1048576;
|
const int mb_size = 1048576;
|
||||||
LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << " MB";
|
LOG(CONSOLE) << "[GPU Plug-in] Allocated " << ba.size() / mb_size << " MB";
|
||||||
}
|
}
|
||||||
|
|
||||||
initialised = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy or init to do every iteration
|
// copy or init to do every iteration
|
||||||
@ -355,7 +399,20 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
|||||||
|
|
||||||
dh::synchronize_n_devices(n_devices, dList);
|
dh::synchronize_n_devices(n_devices, dList);
|
||||||
|
|
||||||
|
if (!initialised) {
|
||||||
|
gpu_init_time = time1.elapsedSeconds() - cpu_init_time;
|
||||||
|
gpu_time = -cpu_init_time;
|
||||||
|
if (param.debug_verbose) { // Only done once for each training session
|
||||||
|
LOG(CONSOLE) << "[GPU Plug-in] Time for GPU operations during First Call to InitData() "
|
||||||
|
<< gpu_init_time << " sec";
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
p_last_fmat_ = &fmat;
|
p_last_fmat_ = &fmat;
|
||||||
|
|
||||||
|
initialised = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUHistBuilder::BuildHist(int depth) {
|
void GPUHistBuilder::BuildHist(int depth) {
|
||||||
@ -623,7 +680,7 @@ __global__ void find_split_kernel(
|
|||||||
#define MIN_BLOCK_THREADS 32
|
#define MIN_BLOCK_THREADS 32
|
||||||
#define CHUNK_BLOCK_THREADS 32
|
#define CHUNK_BLOCK_THREADS 32
|
||||||
// MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due
|
// MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due
|
||||||
// to CUDA compatibility 35 and above requirement
|
// to CUDA capability 35 and above requirement
|
||||||
// for Maximum number of threads per block
|
// for Maximum number of threads per block
|
||||||
#define MAX_BLOCK_THREADS 1024
|
#define MAX_BLOCK_THREADS 1024
|
||||||
|
|
||||||
@ -1082,6 +1139,8 @@ bool GPUHistBuilder::UpdatePredictionCache(
|
|||||||
|
|
||||||
void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
|
void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
|
||||||
DMatrix* p_fmat, RegTree* p_tree) {
|
DMatrix* p_fmat, RegTree* p_tree) {
|
||||||
|
dh::Timer time0;
|
||||||
|
|
||||||
this->InitData(gpair, *p_fmat, *p_tree);
|
this->InitData(gpair, *p_fmat, *p_tree);
|
||||||
this->InitFirstNode(gpair);
|
this->InitFirstNode(gpair);
|
||||||
this->ColSampleTree();
|
this->ColSampleTree();
|
||||||
@ -1097,6 +1156,24 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
|
|||||||
int master_device = dList[0];
|
int master_device = dList[0];
|
||||||
dh::safe_cuda(cudaSetDevice(master_device));
|
dh::safe_cuda(cudaSetDevice(master_device));
|
||||||
dense2sparse_tree(p_tree, nodes[0].tbegin(), nodes[0].tend(), param);
|
dense2sparse_tree(p_tree, nodes[0].tbegin(), nodes[0].tend(), param);
|
||||||
|
|
||||||
|
gpu_time += time0.elapsedSeconds();
|
||||||
|
|
||||||
|
if (param.debug_verbose) {
|
||||||
|
LOG(CONSOLE) << "[GPU Plug-in] Cumulative GPU Time excluding initial time "
|
||||||
|
<< (gpu_time - gpu_init_time)
|
||||||
|
<< " sec";
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (param.debug_verbose) {
|
||||||
|
LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time "
|
||||||
|
<< cpu_time.elapsedSeconds() << " sec";
|
||||||
|
LOG(CONSOLE) << "[GPU Plug-in] Cumulative CPU Time excluding initial time "
|
||||||
|
<< (cpu_time.elapsedSeconds() - cpu_init_time - gpu_time)
|
||||||
|
<< " sec";
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -122,6 +122,13 @@ class GPUHistBuilder {
|
|||||||
std::vector<cudaStream_t *> streams;
|
std::vector<cudaStream_t *> streams;
|
||||||
std::vector<ncclComm_t> comms;
|
std::vector<ncclComm_t> comms;
|
||||||
std::vector<std::vector<ncclComm_t>> find_split_comms;
|
std::vector<std::vector<ncclComm_t>> find_split_comms;
|
||||||
|
|
||||||
|
double cpu_init_time;
|
||||||
|
double gpu_init_time;
|
||||||
|
dh::Timer cpu_time;
|
||||||
|
double gpu_time;
|
||||||
|
|
||||||
|
|
||||||
};
|
};
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -32,6 +32,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'nthread': 0,
|
'nthread': 0,
|
||||||
'eta': 1,
|
'eta': 1,
|
||||||
'silent': 1,
|
'silent': 1,
|
||||||
|
'debug_verbose': 0,
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
ag_param2 = {'max_depth': 2,
|
ag_param2 = {'max_depth': 2,
|
||||||
@ -39,6 +40,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'nthread': 0,
|
'nthread': 0,
|
||||||
'eta': 1,
|
'eta': 1,
|
||||||
'silent': 1,
|
'silent': 1,
|
||||||
|
'debug_verbose': 0,
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
ag_res = {}
|
ag_res = {}
|
||||||
@ -63,6 +65,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'nthread': 0,
|
'nthread': 0,
|
||||||
'tree_method': 'gpu_exact',
|
'tree_method': 'gpu_exact',
|
||||||
'max_depth': 3,
|
'max_depth': 3,
|
||||||
|
'debug_verbose': 0,
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
res = {}
|
res = {}
|
||||||
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
|
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train'), (dtest, 'test')],
|
||||||
@ -80,6 +83,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'nthread': 0,
|
'nthread': 0,
|
||||||
'tree_method': 'gpu_exact',
|
'tree_method': 'gpu_exact',
|
||||||
'max_depth': 2,
|
'max_depth': 2,
|
||||||
|
'debug_verbose': 0,
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
res = {}
|
res = {}
|
||||||
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
||||||
@ -134,6 +138,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'nthread': 0,
|
'nthread': 0,
|
||||||
'eta': 1,
|
'eta': 1,
|
||||||
'silent': 1,
|
'silent': 1,
|
||||||
|
'debug_verbose': 0,
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
ag_param2 = {'max_depth': max_depth,
|
ag_param2 = {'max_depth': max_depth,
|
||||||
@ -141,6 +146,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'tree_method': 'gpu_hist',
|
'tree_method': 'gpu_hist',
|
||||||
'eta': 1,
|
'eta': 1,
|
||||||
'silent': 1,
|
'silent': 1,
|
||||||
|
'debug_verbose': 0,
|
||||||
'n_gpus': 1,
|
'n_gpus': 1,
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'max_bin': max_bin,
|
'max_bin': max_bin,
|
||||||
@ -150,6 +156,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'tree_method': 'gpu_hist',
|
'tree_method': 'gpu_hist',
|
||||||
'eta': 1,
|
'eta': 1,
|
||||||
'silent': 1,
|
'silent': 1,
|
||||||
|
'debug_verbose': 0,
|
||||||
'n_gpus': n_gpus,
|
'n_gpus': n_gpus,
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'max_bin': max_bin,
|
'max_bin': max_bin,
|
||||||
@ -187,6 +194,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'max_depth': max_depth,
|
'max_depth': max_depth,
|
||||||
'n_gpus': 1,
|
'n_gpus': 1,
|
||||||
'max_bin': max_bin,
|
'max_bin': max_bin,
|
||||||
|
'debug_verbose': 0,
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
res = {}
|
res = {}
|
||||||
#eprint("digits: grow_gpu_hist updater 1 gpu");
|
#eprint("digits: grow_gpu_hist updater 1 gpu");
|
||||||
@ -200,6 +208,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'max_depth': max_depth,
|
'max_depth': max_depth,
|
||||||
'n_gpus': n_gpus,
|
'n_gpus': n_gpus,
|
||||||
'max_bin': max_bin,
|
'max_bin': max_bin,
|
||||||
|
'debug_verbose': 0,
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
res2 = {}
|
res2 = {}
|
||||||
#eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
|
#eprint("digits: grow_gpu_hist updater %d gpus" % (n_gpus));
|
||||||
@ -223,6 +232,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'max_depth': max_depth,
|
'max_depth': max_depth,
|
||||||
'n_gpus': n_gpus,
|
'n_gpus': n_gpus,
|
||||||
'max_bin': max_bin,
|
'max_bin': max_bin,
|
||||||
|
'debug_verbose': 0,
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
res = {}
|
res = {}
|
||||||
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
||||||
@ -262,6 +272,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'tree_method': 'gpu_hist',
|
'tree_method': 'gpu_hist',
|
||||||
'max_depth': max_depth,
|
'max_depth': max_depth,
|
||||||
'n_gpus': n_gpus,
|
'n_gpus': n_gpus,
|
||||||
|
'debug_verbose': 0,
|
||||||
'eval_metric': 'auc',
|
'eval_metric': 'auc',
|
||||||
'max_bin': max_bin}
|
'max_bin': max_bin}
|
||||||
res = {}
|
res = {}
|
||||||
@ -280,6 +291,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'colsample_bytree': 0.5,
|
'colsample_bytree': 0.5,
|
||||||
'colsample_bylevel': 0.5,
|
'colsample_bylevel': 0.5,
|
||||||
'subsample': 0.5,
|
'subsample': 0.5,
|
||||||
|
'debug_verbose': 0,
|
||||||
'max_bin': max_bin}
|
'max_bin': max_bin}
|
||||||
res = {}
|
res = {}
|
||||||
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
xgb.train(param, dtrain2, num_rounds, [(dtrain2, 'train')], evals_result=res)
|
||||||
@ -293,6 +305,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'tree_method': 'gpu_hist',
|
'tree_method': 'gpu_hist',
|
||||||
'max_depth': 2,
|
'max_depth': 2,
|
||||||
'n_gpus': n_gpus,
|
'n_gpus': n_gpus,
|
||||||
|
'debug_verbose': 0,
|
||||||
'eval_metric': 'auc',
|
'eval_metric': 'auc',
|
||||||
'max_bin': 2}
|
'max_bin': 2}
|
||||||
res = {}
|
res = {}
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
#pylint: skip-file
|
#pylint: skip-file
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
sys.path.append("../../tests/python")
|
sys.path.append("../../tests/python")
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import testing as tm
|
import testing as tm
|
||||||
@ -33,11 +34,16 @@ class TestGPU(unittest.TestCase):
|
|||||||
for rows in rowslist:
|
for rows in rowslist:
|
||||||
|
|
||||||
eprint("Creating train data rows=%d cols=%d" % (rows,cols))
|
eprint("Creating train data rows=%d cols=%d" % (rows,cols))
|
||||||
|
tmp = time.time()
|
||||||
np.random.seed(7)
|
np.random.seed(7)
|
||||||
X = np.random.rand(rows, cols)
|
X = np.random.rand(rows, cols)
|
||||||
y = np.random.rand(rows)
|
y = np.random.rand(rows)
|
||||||
|
print("Time to Create Data: %r" % (time.time() - tmp))
|
||||||
|
|
||||||
eprint("Starting DMatrix(X,y)")
|
eprint("Starting DMatrix(X,y)")
|
||||||
ag_dtrain = xgb.DMatrix(X,y,nthread=0)
|
tmp = time.time()
|
||||||
|
ag_dtrain = xgb.DMatrix(X,y,nthread=40)
|
||||||
|
print("Time to DMatrix: %r" % (time.time() - tmp))
|
||||||
|
|
||||||
max_depth=6
|
max_depth=6
|
||||||
max_bin=1024
|
max_bin=1024
|
||||||
@ -48,6 +54,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'nthread': 0,
|
'nthread': 0,
|
||||||
'eta': 1,
|
'eta': 1,
|
||||||
'silent': 0,
|
'silent': 0,
|
||||||
|
'debug_verbose': 5,
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
ag_paramb = {'max_depth': max_depth,
|
ag_paramb = {'max_depth': max_depth,
|
||||||
@ -55,6 +62,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'nthread': 0,
|
'nthread': 0,
|
||||||
'eta': 1,
|
'eta': 1,
|
||||||
'silent': 0,
|
'silent': 0,
|
||||||
|
'debug_verbose': 5,
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'eval_metric': 'auc'}
|
'eval_metric': 'auc'}
|
||||||
ag_param2 = {'max_depth': max_depth,
|
ag_param2 = {'max_depth': max_depth,
|
||||||
@ -62,6 +70,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'nthread': 0,
|
'nthread': 0,
|
||||||
'eta': 1,
|
'eta': 1,
|
||||||
'silent': 0,
|
'silent': 0,
|
||||||
|
'debug_verbose': 5,
|
||||||
'n_gpus': 1,
|
'n_gpus': 1,
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'max_bin': max_bin,
|
'max_bin': max_bin,
|
||||||
@ -71,6 +80,7 @@ class TestGPU(unittest.TestCase):
|
|||||||
'nthread': 0,
|
'nthread': 0,
|
||||||
'eta': 1,
|
'eta': 1,
|
||||||
'silent': 0,
|
'silent': 0,
|
||||||
|
'debug_verbose': 5,
|
||||||
'n_gpus': -1,
|
'n_gpus': -1,
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'max_bin': max_bin,
|
'max_bin': max_bin,
|
||||||
@ -81,16 +91,23 @@ class TestGPU(unittest.TestCase):
|
|||||||
ag_res3 = {}
|
ag_res3 = {}
|
||||||
|
|
||||||
num_rounds = 1
|
num_rounds = 1
|
||||||
|
tmp = time.time()
|
||||||
|
#eprint("hist updater")
|
||||||
|
#xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
|
||||||
|
# evals_result=ag_resb)
|
||||||
|
#print("Time to Train: %s seconds" % (str(time.time() - tmp)))
|
||||||
|
|
||||||
eprint("hist updater")
|
tmp = time.time()
|
||||||
xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
|
|
||||||
evals_result=ag_resb)
|
|
||||||
eprint("gpu_hist updater 1 gpu")
|
eprint("gpu_hist updater 1 gpu")
|
||||||
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
|
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
|
||||||
evals_result=ag_res2)
|
evals_result=ag_res2)
|
||||||
|
print("Time to Train: %s seconds" % (str(time.time() - tmp)))
|
||||||
|
|
||||||
|
tmp = time.time()
|
||||||
eprint("gpu_hist updater all gpus")
|
eprint("gpu_hist updater all gpus")
|
||||||
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
|
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
|
||||||
evals_result=ag_res3)
|
evals_result=ag_res3)
|
||||||
|
print("Time to Train: %s seconds" % (str(time.time() - tmp)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user