[GPU-Plugin] Fix gpu_hist to allow matrices with more than just 2^{32} elements. Also fixed CPU hist algorithm. (#2518)
This commit is contained in:
committed by
Rory Mitchell
parent
c85bf9859e
commit
ca7fc9fda3
@@ -441,7 +441,7 @@ class bulk_allocator {
|
||||
|
||||
public:
|
||||
~bulk_allocator() {
|
||||
for (int i = 0; i < d_ptr.size(); i++) {
|
||||
for (size_t i = 0; i < d_ptr.size(); i++) {
|
||||
if (!(d_ptr[i] == nullptr)) {
|
||||
safe_cuda(cudaSetDevice(_device_idx[i]));
|
||||
safe_cuda(cudaFree(d_ptr[i]));
|
||||
@@ -522,7 +522,7 @@ inline size_t available_memory(int device_idx) {
|
||||
template <typename T>
|
||||
void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
|
||||
thrust::host_vector<T> h = v;
|
||||
for (int i = 0; i < std::min(max_items, h.size()); i++) {
|
||||
for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
|
||||
std::cout << " " << h[i];
|
||||
}
|
||||
std::cout << "\n";
|
||||
@@ -531,7 +531,7 @@ void print(const thrust::device_vector<T> &v, size_t max_items = 10) {
|
||||
template <typename T>
|
||||
void print(const dvec<T> &v, size_t max_items = 10) {
|
||||
std::vector<T> h = v.as_vector();
|
||||
for (int i = 0; i < std::min(max_items, h.size()); i++) {
|
||||
for (size_t i = 0; i < std::min(max_items, h.size()); i++) {
|
||||
std::cout << " " << h[i];
|
||||
}
|
||||
std::cout << "\n";
|
||||
@@ -539,10 +539,10 @@ void print(const dvec<T> &v, size_t max_items = 10) {
|
||||
|
||||
template <typename T>
|
||||
void print(char *label, const thrust::device_vector<T> &v,
|
||||
const char *format = "%d ", int max = 10) {
|
||||
const char *format = "%d ", size_t max = 10) {
|
||||
thrust::host_vector<T> h_v = v;
|
||||
std::cout << label << ":\n";
|
||||
for (int i = 0; i < std::min(static_cast<int>(h_v.size()), max); i++) {
|
||||
for (size_t i = 0; i < std::min(static_cast<size_t>(h_v.size()), max); i++) {
|
||||
printf(format, h_v[i]);
|
||||
}
|
||||
std::cout << "\n";
|
||||
@@ -593,6 +593,7 @@ __global__ void launch_n_kernel(int device_idx, size_t begin, size_t end,
|
||||
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
|
||||
inline void launch_n(int device_idx, size_t n, L lambda) {
|
||||
safe_cuda(cudaSetDevice(device_idx));
|
||||
// TODO: Template on n so GRID_SIZE always fits into int.
|
||||
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
|
||||
#if defined(__CUDACC__)
|
||||
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
|
||||
@@ -607,6 +608,7 @@ inline void multi_launch_n(size_t n, int n_devices, L lambda) {
|
||||
CHECK_LE(n_devices, n_visible_devices()) << "Number of devices requested "
|
||||
"needs to be less than equal to "
|
||||
"number of visible devices.";
|
||||
// TODO: Template on n so GRID_SIZE always fits into int.
|
||||
const int GRID_SIZE = div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS);
|
||||
#if defined(__CUDACC__)
|
||||
n_devices = n_devices > n ? n : n_devices;
|
||||
|
||||
@@ -19,8 +19,8 @@ namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
|
||||
bst_uint element_begin, bst_uint element_end,
|
||||
bst_uint row_begin, bst_uint row_end, int n_bins) {
|
||||
bst_ulong element_begin, bst_ulong element_end,
|
||||
bst_ulong row_begin, bst_ulong row_end, int n_bins) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
CHECK(gidx_buffer.size()) << "gidx_buffer must be externally allocated";
|
||||
CHECK_EQ(row_ptr.size(), (row_end - row_begin) + 1)
|
||||
@@ -31,15 +31,15 @@ void DeviceGMat::Init(int device_idx, const common::GHistIndexMatrix& gmat,
|
||||
cbw.Write(host_buffer.data(), gmat.index.begin() + element_begin,
|
||||
gmat.index.begin() + element_end);
|
||||
gidx_buffer = host_buffer;
|
||||
gidx = common::CompressedIterator<int>(gidx_buffer.data(), n_bins);
|
||||
gidx = common::CompressedIterator<uint32_t>(gidx_buffer.data(), n_bins);
|
||||
|
||||
// row_ptr
|
||||
thrust::copy(gmat.row_ptr.data() + row_begin,
|
||||
gmat.row_ptr.data() + row_end + 1, row_ptr.tbegin());
|
||||
// normalise row_ptr
|
||||
bst_uint start = gmat.row_ptr[row_begin];
|
||||
size_t start = gmat.row_ptr[row_begin];
|
||||
thrust::transform(row_ptr.tbegin(), row_ptr.tend(), row_ptr.tbegin(),
|
||||
[=] __device__(int val) { return val - start; });
|
||||
[=] __device__(size_t val) { return val - start; });
|
||||
}
|
||||
|
||||
void DeviceHist::Init(int n_bins_in) {
|
||||
@@ -193,7 +193,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||
device_row_segments.push_back(0);
|
||||
device_element_segments.push_back(0);
|
||||
bst_uint offset = 0;
|
||||
size_t shard_size = std::ceil(static_cast<double>(num_rows) / n_devices);
|
||||
bst_uint shard_size = std::ceil(static_cast<double>(num_rows) / n_devices);
|
||||
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
|
||||
int device_idx = dList[d_idx];
|
||||
offset += shard_size;
|
||||
@@ -261,7 +261,7 @@ void GPUHistBuilder::InitData(const std::vector<bst_gpair>& gpair,
|
||||
int device_idx = dList[d_idx];
|
||||
bst_uint num_rows_segment =
|
||||
device_row_segments[d_idx + 1] - device_row_segments[d_idx];
|
||||
bst_uint num_elements_segment =
|
||||
bst_ulong num_elements_segment =
|
||||
device_element_segments[d_idx + 1] - device_element_segments[d_idx];
|
||||
ba.allocate(
|
||||
device_idx, &(hist_vec[d_idx].data),
|
||||
@@ -360,7 +360,7 @@ void GPUHistBuilder::BuildHist(int depth) {
|
||||
auto hist_builder = hist_vec[d_idx].GetBuilder();
|
||||
dh::TransformLbs(
|
||||
device_idx, &temp_memory[d_idx], end - begin, d_row_ptr,
|
||||
row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
|
||||
row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
|
||||
int nidx = d_position[local_ridx]; // OPTMARK: latency
|
||||
if (!is_active(nidx, depth)) return;
|
||||
|
||||
@@ -606,7 +606,11 @@ __global__ void find_split_kernel(
|
||||
}
|
||||
|
||||
#define MIN_BLOCK_THREADS 32
|
||||
#define MAX_BLOCK_THREADS 1024 // hard-coded maximum block size
|
||||
#define CHUNK_BLOCK_THREADS 32
|
||||
// MAX_BLOCK_THREADS of 1024 is hard-coded maximum block size due
|
||||
// to CUDA compatibility 35 and above requirement
|
||||
// for Maximum number of threads per block
|
||||
#define MAX_BLOCK_THREADS 1024
|
||||
|
||||
void GPUHistBuilder::FindSplit(int depth) {
|
||||
// Specialised based on max_bins
|
||||
@@ -622,7 +626,7 @@ void GPUHistBuilder::FindSplitSpecialize(int depth) {
|
||||
if (param.max_bin <= BLOCK_THREADS) {
|
||||
LaunchFindSplit<BLOCK_THREADS>(depth);
|
||||
} else {
|
||||
this->FindSplitSpecialize<BLOCK_THREADS + 32>(depth);
|
||||
this->FindSplitSpecialize<BLOCK_THREADS + CHUNK_BLOCK_THREADS>(depth);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -885,7 +889,7 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
|
||||
size_t begin = device_row_segments[d_idx];
|
||||
size_t end = device_row_segments[d_idx + 1];
|
||||
|
||||
dh::launch_n(device_idx, end - begin, [=] __device__(int local_idx) {
|
||||
dh::launch_n(device_idx, end - begin, [=] __device__(size_t local_idx) {
|
||||
int pos = d_position[local_idx];
|
||||
if (!is_active(pos, depth)) {
|
||||
return;
|
||||
@@ -896,7 +900,8 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
|
||||
return;
|
||||
}
|
||||
|
||||
int gidx = d_gidx[local_idx * n_columns + node.split.findex];
|
||||
int gidx = d_gidx[local_idx *
|
||||
static_cast<size_t>(n_columns) + static_cast<size_t>(node.split.findex)];
|
||||
|
||||
float fvalue = d_gidx_fvalue_map[gidx];
|
||||
|
||||
@@ -955,7 +960,7 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
|
||||
|
||||
dh::TransformLbs(
|
||||
device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr,
|
||||
row_end - row_begin, [=] __device__(int local_idx, int local_ridx) {
|
||||
row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
|
||||
int pos = d_position[local_ridx];
|
||||
if (!is_active(pos, depth)) {
|
||||
return;
|
||||
@@ -1065,25 +1070,13 @@ void GPUHistBuilder::Update(const std::vector<bst_gpair>& gpair,
|
||||
this->InitData(gpair, *p_fmat, *p_tree);
|
||||
this->InitFirstNode(gpair);
|
||||
this->ColSampleTree();
|
||||
// long long int elapsed=0;
|
||||
for (int depth = 0; depth < param.max_depth; depth++) {
|
||||
this->ColSampleLevel();
|
||||
|
||||
// dh::Timer time;
|
||||
this->BuildHist(depth);
|
||||
// elapsed+=time.elapsed();
|
||||
// printf("depth=%d\n",depth);
|
||||
// time.printElapsed("BH Time");
|
||||
|
||||
// dh::Timer timesplit;
|
||||
this->FindSplit(depth);
|
||||
// timesplit.printElapsed("FS Time");
|
||||
|
||||
// dh::Timer timeupdatepos;
|
||||
this->UpdatePosition(depth);
|
||||
// timeupdatepos.printElapsed("UP Time");
|
||||
}
|
||||
// printf("Total BuildHist Time=%lld\n",elapsed);
|
||||
|
||||
// done with multi-GPU, pass back result from master to tree on host
|
||||
int master_device = dList[0];
|
||||
|
||||
@@ -18,10 +18,10 @@ namespace tree {
|
||||
|
||||
struct DeviceGMat {
|
||||
dh::dvec<common::compressed_byte_t> gidx_buffer;
|
||||
common::CompressedIterator<int > gidx;
|
||||
dh::dvec<int> row_ptr;
|
||||
common::CompressedIterator<uint32_t> gidx;
|
||||
dh::dvec<size_t> row_ptr;
|
||||
void Init(int device_idx, const common::GHistIndexMatrix &gmat,
|
||||
bst_uint begin, bst_uint end, bst_uint row_begin, bst_uint row_end,int n_bins);
|
||||
bst_ulong element_begin, bst_ulong element_end, bst_ulong row_begin, bst_ulong row_end,int n_bins);
|
||||
};
|
||||
|
||||
struct HistBuilder {
|
||||
@@ -99,7 +99,7 @@ class GPUHistBuilder {
|
||||
// below vectors are for each devices used
|
||||
std::vector<int> dList;
|
||||
std::vector<int> device_row_segments;
|
||||
std::vector<int> device_element_segments;
|
||||
std::vector<size_t> device_element_segments;
|
||||
|
||||
std::vector<dh::CubMemory> temp_memory;
|
||||
std::vector<DeviceHist> hist_vec;
|
||||
|
||||
134
plugin/updater_gpu/test/python/test_large.py
Normal file
134
plugin/updater_gpu/test/python/test_large.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from __future__ import print_function
|
||||
#pylint: skip-file
|
||||
import sys
|
||||
sys.path.append("../../tests/python")
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
import unittest
|
||||
from sklearn.datasets import make_classification
|
||||
def eprint(*args, **kwargs):
|
||||
print(*args, file=sys.stderr, **kwargs) ; sys.stderr.flush()
|
||||
print(*args, file=sys.stdout, **kwargs) ; sys.stdout.flush()
|
||||
|
||||
eprint("Testing Big Data (this may take a while)")
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
# "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricalsxsy
|
||||
cols = 31
|
||||
# reduced to fit onto 1 gpu but still be large
|
||||
rows2 = 5000 # medium
|
||||
#rows2 = 4032 # fake large for testing
|
||||
rows1 = 42360032 # large
|
||||
#rows2 = 152360032 # can do this for multi-gpu test (very large)
|
||||
rowslist = [rows1, rows2]
|
||||
|
||||
|
||||
class TestGPU(unittest.TestCase):
|
||||
def test_large(self):
|
||||
eprint("Starting test for large data")
|
||||
tm._skip_if_no_sklearn()
|
||||
from sklearn.datasets import load_digits
|
||||
try:
|
||||
from sklearn.model_selection import train_test_split
|
||||
except:
|
||||
from sklearn.cross_validation import train_test_split
|
||||
|
||||
|
||||
for rows in rowslist:
|
||||
|
||||
eprint("Creating train data rows=%d cols=%d" % (rows,cols))
|
||||
X, y = make_classification(rows, n_features=cols, random_state=7)
|
||||
rowstest = int(rows*0.2)
|
||||
eprint("Creating test data rows=%d cols=%d" % (rowstest,cols))
|
||||
# note the new random state. if chose same as train random state, exact methods can memorize and do very well on test even for random data, while hist cannot
|
||||
Xtest, ytest = make_classification(rowstest, n_features=cols, random_state=8)
|
||||
|
||||
eprint("Starting DMatrix(X,y)")
|
||||
ag_dtrain = xgb.DMatrix(X,y)
|
||||
eprint("Starting DMatrix(Xtest,ytest)")
|
||||
ag_dtest = xgb.DMatrix(Xtest,ytest)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
max_depth=6
|
||||
max_bin=1024
|
||||
|
||||
# regression test --- hist must be same as exact on all-categorial data
|
||||
ag_param = {'max_depth': max_depth,
|
||||
'tree_method': 'exact',
|
||||
#'nthread': 1,
|
||||
'eta': 1,
|
||||
'silent': 0,
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'auc'}
|
||||
ag_paramb = {'max_depth': max_depth,
|
||||
'tree_method': 'hist',
|
||||
#'nthread': 1,
|
||||
'eta': 1,
|
||||
'silent': 0,
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'auc'}
|
||||
ag_param2 = {'max_depth': max_depth,
|
||||
'tree_method': 'gpu_hist',
|
||||
'eta': 1,
|
||||
'silent': 0,
|
||||
'n_gpus': 1,
|
||||
'objective': 'binary:logistic',
|
||||
'max_bin': max_bin,
|
||||
'eval_metric': 'auc'}
|
||||
ag_param3 = {'max_depth': max_depth,
|
||||
'tree_method': 'gpu_hist',
|
||||
'eta': 1,
|
||||
'silent': 0,
|
||||
'n_gpus': -1,
|
||||
'objective': 'binary:logistic',
|
||||
'max_bin': max_bin,
|
||||
'eval_metric': 'auc'}
|
||||
#ag_param4 = {'max_depth': max_depth,
|
||||
# 'tree_method': 'gpu_exact',
|
||||
# 'eta': 1,
|
||||
# 'silent': 0,
|
||||
# 'n_gpus': 1,
|
||||
# 'objective': 'binary:logistic',
|
||||
# 'max_bin': max_bin,
|
||||
# 'eval_metric': 'auc'}
|
||||
ag_res = {}
|
||||
ag_resb = {}
|
||||
ag_res2 = {}
|
||||
ag_res3 = {}
|
||||
#ag_res4 = {}
|
||||
|
||||
num_rounds = 1
|
||||
|
||||
eprint("normal updater")
|
||||
xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=ag_res)
|
||||
eprint("hist updater")
|
||||
xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=ag_resb)
|
||||
eprint("gpu_hist updater 1 gpu")
|
||||
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=ag_res2)
|
||||
eprint("gpu_hist updater all gpus")
|
||||
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=ag_res3)
|
||||
#eprint("gpu_exact updater")
|
||||
#xgb.train(ag_param4, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
# evals_result=ag_res4)
|
||||
|
||||
assert np.fabs(ag_res['train']['auc'][0] - ag_resb['train']['auc'][0])<0.001
|
||||
assert np.fabs(ag_res['train']['auc'][0] - ag_res2['train']['auc'][0])<0.001
|
||||
assert np.fabs(ag_res['train']['auc'][0] - ag_res3['train']['auc'][0])<0.001
|
||||
#assert np.fabs(ag_res['train']['auc'][0] - ag_res4['train']['auc'][0])<0.001
|
||||
|
||||
assert np.fabs(ag_res['test']['auc'][0] - ag_resb['test']['auc'][0])<0.01
|
||||
assert np.fabs(ag_res['test']['auc'][0] - ag_res2['test']['auc'][0])<0.01
|
||||
assert np.fabs(ag_res['test']['auc'][0] - ag_res3['test']['auc'][0])<0.01
|
||||
#assert np.fabs(ag_res['test']['auc'][0] - ag_res4['test']['auc'][0])<0.01
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user